Skip to content

numpy_dtype_user: Pave way for downstream libs to define user dtypes #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1882,14 +1882,22 @@ template <typename T> struct move_if_unreferenced<T, enable_if_t<all_of<
>::value>> : std::true_type {};
template <typename T> using move_never = negation<move_common<T>>;

template <typename type, typename SFINAE = void>
struct cast_is_known_safe : public std::false_type {};

template <typename type>
struct cast_is_known_safe<type,
enable_if_t<std::is_base_of<type_caster_generic, make_caster<type>>::value>> : public std::true_type {};

// Detect whether returning a `type` from a cast on type's type_caster is going to result in a
// reference or pointer to a local variable of the type_caster. Basically, only
// non-reference/pointer `type`s and reference/pointers from a type_caster_generic are safe;
// everything else returns a reference/pointer to a local variable.
template <typename type> using cast_is_temporary_value_reference = bool_constant<
(std::is_reference<type>::value || std::is_pointer<type>::value) &&
!std::is_base_of<type_caster_generic, make_caster<type>>::value &&
!std::is_same<intrinsic_t<type>, void>::value
!std::is_same<intrinsic_t<type>, void>::value &&
!cast_is_known_safe<type>::value
>;

// When a value returned from a C++ function is being cast back to Python, we almost always want to
Expand Down
4 changes: 3 additions & 1 deletion include/pybind11/detail/internals.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,15 @@ struct internals {
/// Additional type information which does not fit into the PyTypeObject.
/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
struct type_info {
using implicit_conversion_func = PyObject *(*)(PyObject *, PyTypeObject *);

PyTypeObject *type;
const std::type_info *cpptype;
size_t type_size, type_align, holder_size_in_ptrs;
void *(*operator_new)(size_t);
void (*init_instance)(instance *, holder_erased);
void (*dealloc)(value_and_holder &v_h);
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
std::vector<implicit_conversion_func> implicit_conversions;
std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
Expand Down
2 changes: 1 addition & 1 deletion include/pybind11/eigen.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ template <typename props> handle eigen_array_cast(typename props::Type const &sr
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
array a;
using Scalar = typename props::Type::Scalar;
bool is_pyobject = static_cast<pybind11::detail::npy_api::constants>(npy_format_descriptor<Scalar>::value) == npy_api::NPY_OBJECT_;
bool is_pyobject = is_pyobject_dtype<Scalar>::value;

if (!is_pyobject) {
if (props::vector)
Expand Down
250 changes: 223 additions & 27 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,105 @@ struct PyVoidScalarObject_Proxy {
PyObject *base;
};

// UFuncs.
using npy_intp = Py_intptr_t;

typedef void (*PyUFuncGenericFunction)(
char **args, npy_intp *dimensions, npy_intp *strides, void *innerloopdata);

typedef void (PyArray_VectorUnaryFunc)(
void* from_, void* to_, npy_intp n, void* fromarr, void* toarr);

typedef struct {
PyObject_HEAD
int nin;
int nout;
int nargs;
int identity;
PyUFuncGenericFunction *functions;
void **data;
int ntypes;
int reserved1;
const char *name;
char *types;
const char *doc;
void *ptr;
PyObject *obj;
PyObject *userloops;
uint32_t *op_flags;
uint32_t *iter_flags;
} PyUFuncObject;

// Manually defined :(
constexpr int NPY_NTYPES_ABI_COMPATIBLE = 21;
constexpr int NPY_NSORTS = 3;

// TODO(eric.cousineau): Fill this out as needed for type safety.
// TODO(eric.cousineau): Do not define these if NPY headers are present (for debugging).
using PyArray_GetItemFunc = void;
using PyArray_SetItemFunc = void;
using PyArray_CopySwapNFunc = void;
using PyArray_CopySwapFunc = void;
using PyArray_CompareFunc = void;
using PyArray_ArgFunc = void;
using PyArray_DotFunc = void;
using PyArray_ScanFunc = void;
using PyArray_FromStrFunc = void;
using PyArray_NonzeroFunc = void;
using PyArray_FillFunc = void;
using PyArray_FillWithScalarFunc = void;
using PyArray_SortFunc = void;
using PyArray_ArgSortFunc = void;
using PyArray_ScalarKindFunc = void;
using PyArray_FastClipFunc = void;
using PyArray_FastPutmaskFunc = void;
using PyArray_FastTakeFunc = void;
using PyArray_ArgFunc = void;

typedef struct {
PyArray_VectorUnaryFunc *cast[NPY_NTYPES_ABI_COMPATIBLE];
PyArray_GetItemFunc *getitem;
PyArray_SetItemFunc *setitem;
PyArray_CopySwapNFunc *copyswapn;
PyArray_CopySwapFunc *copyswap;
PyArray_CompareFunc *compare;
PyArray_ArgFunc *argmax;
PyArray_DotFunc *dotfunc;
PyArray_ScanFunc *scanfunc;
PyArray_FromStrFunc *fromstr;
PyArray_NonzeroFunc *nonzero;
PyArray_FillFunc *fill;
PyArray_FillWithScalarFunc *fillwithscalar;
PyArray_SortFunc *sort[NPY_NSORTS];
PyArray_ArgSortFunc *argsort[NPY_NSORTS];
PyObject *castdict;
PyArray_ScalarKindFunc *scalarkind;
int **cancastscalarkindto;
int *cancastto;
PyArray_FastClipFunc *fastclip;
PyArray_FastPutmaskFunc *fastputmask;
PyArray_FastTakeFunc *fasttake;
PyArray_ArgFunc *argmin;
} PyArray_ArrFuncs;

using PyArray_ArrayDescr = void;

typedef struct {
PyObject_HEAD
PyTypeObject *typeobj;
char kind;
char type;
char byteorder;
char unused;
int flags;
int type_num;
int elsize;
int alignment;
PyArray_ArrayDescr *subarray;
PyObject *fields;
PyArray_ArrFuncs *f;
} PyArray_Descr;

struct numpy_type_info {
PyObject* dtype_ptr;
std::string format_str;
Expand Down Expand Up @@ -109,14 +208,16 @@ inline numpy_internals& get_numpy_internals() {
}

struct npy_api {
enum constants {
enum constants : int {
// Array properties
NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
NPY_ARRAY_OWNDATA_ = 0x0004,
NPY_ARRAY_FORCECAST_ = 0x0010,
NPY_ARRAY_ENSUREARRAY_ = 0x0040,
NPY_ARRAY_ALIGNED_ = 0x0100,
NPY_ARRAY_WRITEABLE_ = 0x0400,
// Dtypes
NPY_BOOL_ = 0,
NPY_BYTE_, NPY_UBYTE_,
NPY_SHORT_, NPY_USHORT_,
Expand All @@ -126,9 +227,27 @@ struct npy_api {
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
NPY_OBJECT_ = 17,
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
NPY_STRING_, NPY_UNICODE_, NPY_VOID_,
NPY_USERDEF_ = 256,
// Descriptor flags
NPY_NEEDS_INIT_ = 0x08,
NPY_NEEDS_PYAPI_ = 0x10,
NPY_USE_GETITEM_ = 0x20,
NPY_USE_SETITEM_ = 0x40,
// UFunc
PyUFunc_None_ = -1,
};

typedef enum {
NPY_NOSCALAR_ = -1,
NPY_BOOL_SCALAR_,
NPY_INTPOS_SCALAR_,
NPY_INTNEG_SCALAR_,
NPY_FLOAT_SCALAR_,
NPY_COMPLEX_SCALAR_,
NPY_OBJECT_SCALAR_
} NPY_SCALARKIND;

typedef struct {
Py_intptr_t *ptr;
int len;
Expand All @@ -146,6 +265,7 @@ struct npy_api {
return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
}

// Multiarray.
unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
PyObject *(*PyArray_DescrFromType_)(int);
PyObject *(*PyArray_NewFromDescr_)
Expand All @@ -166,8 +286,29 @@ struct npy_api {
PyObject *(*PyArray_Squeeze_)(PyObject *);
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);

// - Dtypes
PyTypeObject* PyGenericArrType_Type_;
int (*PyArray_RegisterDataType_)(PyArray_Descr* dtype);
int (*PyArray_RegisterCastFunc_)(PyArray_Descr* descr, int totype, PyArray_VectorUnaryFunc* castfunc);
int (*PyArray_RegisterCanCast_)(PyArray_Descr* descr, int totype, NPY_SCALARKIND scalar);
void (*PyArray_InitArrFuncs_)(PyArray_ArrFuncs *f);

// UFuncs.
PyObject* (*PyUFunc_FromFuncAndData_)(
PyUFuncGenericFunction* func, void** data, char* types, int ntypes,
int nin, int nout, int identity, char* name, char* doc, int unused);

int (*PyUFunc_RegisterLoopForType_)(
PyUFuncObject* ufunc, int usertype, PyUFuncGenericFunction function, int* arg_types, void* data);

int (*PyUFunc_ReplaceLoopBySignature_)(
PyUFuncObject *func, PyUFuncGenericFunction newfunc,
int *signature, PyUFuncGenericFunction *oldfunc);
private:
// TODO(eric.cousineau): Rename to `items` or something, since this now applies to types.
enum functions {
// multiarray
API_PyArray_GetNDArrayCFeatureVersion = 211,
API_PyArray_Type = 2,
API_PyArrayDescr_Type = 3,
Expand All @@ -184,38 +325,68 @@ struct npy_api {
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
API_PyArray_Squeeze = 136,
API_PyArray_SetBaseObject = 282
API_PyArray_SetBaseObject = 282,
// - DTypes
API_PyGenericArrType_Type = 10,
API_PyArray_RegisterDataType = 192,
API_PyArray_RegisterCastFunc = 193,
API_PyArray_RegisterCanCast = 194,
API_PyArray_InitArrFuncs = 195,
// umath
API_PyUFunc_FromFuncAndData = 1,
API_PyUFunc_RegisterLoopForType = 2,
API_PyUFunc_ReplaceLoopBySignature = 30,
};

static npy_api lookup() {
module m = module::import("numpy.core.multiarray");
auto c = m.attr("_ARRAY_API");
static void** get_api_ptr(object c) {
#if PY_MAJOR_VERSION >= 3
void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
return (void **) PyCapsule_GetPointer(c.ptr(), NULL);
#else
void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
return (void **) PyCObject_AsVoidPtr(c.ptr());
#endif
}

static npy_api lookup() {
npy_api api;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
DECL_NPY_API(PyArray_Type);
DECL_NPY_API(PyVoidArrType_Type);
DECL_NPY_API(PyArrayDescr_Type);
DECL_NPY_API(PyArray_DescrFromType);
DECL_NPY_API(PyArray_DescrFromScalar);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_Resize);
DECL_NPY_API(PyArray_CopyInto);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
DECL_NPY_API(PyArray_Squeeze);
DECL_NPY_API(PyArray_SetBaseObject);
// multiarray -> _ARRAY_API
{
module multiarray = module::import("numpy.core.multiarray");
auto api_ptr = get_api_ptr(multiarray.attr("_ARRAY_API"));
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
DECL_NPY_API(PyArray_Type);
DECL_NPY_API(PyVoidArrType_Type);
DECL_NPY_API(PyArrayDescr_Type);
DECL_NPY_API(PyArray_DescrFromType);
DECL_NPY_API(PyArray_DescrFromScalar);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_Resize);
DECL_NPY_API(PyArray_CopyInto);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
DECL_NPY_API(PyArray_Squeeze);
DECL_NPY_API(PyArray_SetBaseObject);
// - Dtypes
DECL_NPY_API(PyGenericArrType_Type);
DECL_NPY_API(PyArray_RegisterDataType);
DECL_NPY_API(PyArray_InitArrFuncs);
DECL_NPY_API(PyArray_RegisterCastFunc);
DECL_NPY_API(PyArray_RegisterCanCast);
}
// umath -> _UFUNC_API
{
module umath = module::import("numpy.core.umath");
auto api_ptr = get_api_ptr(umath.attr("_UFUNC_API"));
DECL_NPY_API(PyUFunc_FromFuncAndData);
DECL_NPY_API(PyUFunc_RegisterLoopForType);
DECL_NPY_API(PyUFunc_ReplaceLoopBySignature);
}
#undef DECL_NPY_API
return api;
}
Expand Down Expand Up @@ -465,6 +636,11 @@ class dtype : public object {
return detail::array_descriptor_proxy(m_ptr)->kind;
}

/// Type index for builtin or user-registered dtypes.
int num() const {
return detail::array_descriptor_proxy(m_ptr)->type_num;
}

private:
static object _dtype_from_pep3118() {
static PyObject *obj = module::import("numpy.core._internal")
Expand Down Expand Up @@ -1049,6 +1225,26 @@ template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>
static pybind11::dtype dtype() { return base_descr::dtype(); }
};

template <>
struct npy_format_descriptor<object> {
static pybind11::dtype dtype() {
if (auto ptr = npy_api::get().PyArray_DescrFromType_(npy_api::NPY_OBJECT_))
return reinterpret_borrow<pybind11::dtype>(ptr);
pybind11_fail("Unsupported buffer format!");
}
};

template <>
struct npy_format_descriptor<void> {
static constexpr auto name = detail::_<void>();
static pybind11::dtype dtype() {
if (auto ptr = detail::npy_api::get().PyArray_DescrFromType_(
detail::npy_api::constants::NPY_VOID_))
return reinterpret_borrow<pybind11::dtype>(ptr);
pybind11_fail("Unsupported buffer format!");
}
};

struct field_descriptor {
const char *name;
ssize_t offset;
Expand Down
Loading