forked from pymc-devs/pytensor
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use NumPy C API to perform DimShuffle steps in its C implementation
- Loading branch information
Brandon T. Willard
committed
Dec 15, 2021
1 parent
223ee15
commit e593b0a
Showing
9 changed files
with
126 additions
and
162 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,104 +1,81 @@ | ||
#section support_code_apply | ||
|
||
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject* input, PyArrayObject** res, PARAMS_TYPE* params) { | ||
npy_bool* input_broadcastable; | ||
npy_int64* new_order; | ||
npy_intp nd_in; | ||
npy_intp nd_out; | ||
PyArrayObject* basename; | ||
npy_intp* dimensions; | ||
npy_intp* strides; | ||
|
||
if (!PyArray_IS_C_CONTIGUOUS(params->input_broadcastable)) { | ||
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param input_broadcastable must be C-contiguous."); | ||
return 1; | ||
} | ||
if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) { | ||
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous."); | ||
return 1; | ||
} | ||
input_broadcastable = (npy_bool*) PyArray_DATA(params->input_broadcastable); | ||
new_order = (npy_int64*) PyArray_DATA(params->_new_order); | ||
nd_in = PyArray_SIZE(params->input_broadcastable); | ||
nd_out = PyArray_SIZE(params->_new_order); | ||
|
||
/* check_input_nd */ | ||
if (PyArray_NDIM(input) != nd_in) { | ||
PyErr_SetString(PyExc_NotImplementedError, "input nd"); | ||
return 1; | ||
} | ||
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, | ||
PARAMS_TYPE *params) { | ||
|
||
/* clear_output */ | ||
if (*res) | ||
Py_XDECREF(*res); | ||
// This points to either the original input or a copy we create below. | ||
// Either way, this is what we should be working on/with. | ||
PyArrayObject *_input; | ||
|
||
/* get_base */ | ||
if (params->inplace) { | ||
basename = input; | ||
Py_INCREF((PyObject*)basename); | ||
} else { | ||
basename = | ||
(PyArrayObject*)PyArray_FromAny((PyObject*)input, | ||
NULL, 0, 0, NPY_ARRAY_ALIGNED|NPY_ARRAY_ENSURECOPY, NULL); | ||
} | ||
if (*res) | ||
Py_XDECREF(*res); | ||
|
||
/* shape_statements and strides_statements */ | ||
dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp)); | ||
strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp)); | ||
if (dimensions == NULL || strides == NULL) { | ||
PyErr_NoMemory(); | ||
free(dimensions); | ||
free(strides); | ||
return 1; | ||
}; | ||
|
||
for (npy_intp i = 0; i < nd_out; ++i) { | ||
if (new_order[i] != -1) { | ||
dimensions[i] = PyArray_DIMS(basename)[new_order[i]]; | ||
strides[i] = PyArray_DIMS(basename)[new_order[i]] == 1 ? | ||
0 : PyArray_STRIDES(basename)[new_order[i]]; | ||
} else { | ||
dimensions[i] = 1; | ||
strides[i] = 0; | ||
} | ||
} | ||
if (params->inplace) { | ||
_input = input; | ||
Py_INCREF((PyObject *)_input); | ||
} else { | ||
_input = (PyArrayObject *)PyArray_FromAny( | ||
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY, | ||
NULL); | ||
} | ||
|
||
/* set the strides of the broadcasted dimensions. | ||
* This algorithm is from numpy: PyArray_Newshape() in | ||
* cvs/numpy/numpy/core/src/multiarraymodule.c */ | ||
if (nd_out > 0) { | ||
if (strides[nd_out - 1] == 0) | ||
strides[nd_out - 1] = PyArray_DESCR(basename)->elsize; | ||
for (npy_intp i = nd_out - 2; i > -1; --i) { | ||
if (strides[i] == 0) | ||
strides[i] = strides[i + 1] * dimensions[i + 1]; | ||
} | ||
} | ||
PyArray_Dims permute; | ||
|
||
if (!PyArray_IntpConverter((PyObject *)params->transposition, &permute)) { | ||
return 1; | ||
} | ||
|
||
/* close_bracket */ | ||
// create a new array. | ||
*res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions, | ||
PyArray_TYPE(basename), strides, | ||
PyArray_DATA(basename), PyArray_ITEMSIZE(basename), | ||
// borrow only the writable flag from the base | ||
// the NPY_OWNDATA flag will default to 0. | ||
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(basename)), | ||
NULL); | ||
|
||
if (*res == NULL) { | ||
free(dimensions); | ||
free(strides); | ||
return 1; | ||
/* | ||
res = res.transpose(self.transposition) | ||
*/ | ||
PyArrayObject *transposed_input = | ||
(PyArrayObject *)PyArray_Transpose(_input, &permute); | ||
|
||
PyDimMem_FREE(permute.ptr); | ||
|
||
npy_intp *res_shape = PyArray_DIMS(transposed_input); | ||
npy_intp N_shuffle = PyArray_SIZE(params->shuffle); | ||
npy_intp N_augment = PyArray_SIZE(params->augment); | ||
npy_intp N = N_augment + N_shuffle; | ||
npy_intp *_reshape_shape = (npy_intp *)malloc(N * sizeof(npy_intp)); | ||
|
||
if (_reshape_shape == NULL) { | ||
PyErr_NoMemory(); | ||
free(_reshape_shape); | ||
return 1; | ||
} | ||
|
||
/* | ||
shape = list(res.shape[: len(self.shuffle)]) | ||
for augm in self.augment: | ||
shape.insert(augm, 1) | ||
*/ | ||
npy_intp aug_idx = 0; | ||
int res_idx = 0; | ||
for (npy_intp i = 0; i < N; i++) { | ||
if (aug_idx < N_augment && | ||
i == *((npy_intp *)PyArray_GetPtr(params->augment, &aug_idx))) { | ||
_reshape_shape[i] = 1; | ||
aug_idx++; | ||
} else { | ||
_reshape_shape[i] = res_shape[res_idx]; | ||
res_idx++; | ||
} | ||
} | ||
|
||
PyArray_Dims reshape_shape = {.ptr = _reshape_shape, .len = (int)N}; | ||
|
||
/* res = res.reshape(shape) */ | ||
*res = (PyArrayObject *)PyArray_Newshape(transposed_input, &reshape_shape, | ||
NPY_CORDER); | ||
|
||
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED | ||
PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL); | ||
/* Py_XDECREF(transposed_input); */ | ||
|
||
// we are making a view in both inplace and non-inplace cases | ||
PyArray_SetBaseObject(*res, (PyObject*)basename); | ||
PyDimMem_FREE(reshape_shape.ptr); | ||
|
||
free(strides); | ||
free(dimensions); | ||
if (!*res) { | ||
return 1; | ||
} | ||
|
||
return 0; | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.