-
-
Notifications
You must be signed in to change notification settings - Fork 151
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use NumPy C API to perform DimShuffle steps in its C implementation
- Loading branch information
1 parent
e1f6f13
commit 4d7c16b
Showing
9 changed files
with
124 additions
and
162 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,104 +1,86 @@ | ||
#section support_code_apply | ||
|
||
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject* input, PyArrayObject** res, PARAMS_TYPE* params) { | ||
npy_bool* input_broadcastable; | ||
npy_int64* new_order; | ||
npy_intp nd_in; | ||
npy_intp nd_out; | ||
PyArrayObject* basename; | ||
npy_intp* dimensions; | ||
npy_intp* strides; | ||
|
||
if (!PyArray_IS_C_CONTIGUOUS(params->input_broadcastable)) { | ||
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param input_broadcastable must be C-contiguous."); | ||
return 1; | ||
} | ||
if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) { | ||
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous."); | ||
return 1; | ||
} | ||
input_broadcastable = (npy_bool*) PyArray_DATA(params->input_broadcastable); | ||
new_order = (npy_int64*) PyArray_DATA(params->_new_order); | ||
nd_in = PyArray_SIZE(params->input_broadcastable); | ||
nd_out = PyArray_SIZE(params->_new_order); | ||
|
||
/* check_input_nd */ | ||
if (PyArray_NDIM(input) != nd_in) { | ||
PyErr_SetString(PyExc_NotImplementedError, "input nd"); | ||
return 1; | ||
} | ||
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, | ||
PARAMS_TYPE *params) { | ||
|
||
/* clear_output */ | ||
if (*res) | ||
Py_XDECREF(*res); | ||
// This points to either the original input or a copy we create below. | ||
// Either way, this is what we should be working on/with. | ||
PyArrayObject *_input; | ||
|
||
/* get_base */ | ||
if (params->inplace) { | ||
basename = input; | ||
Py_INCREF((PyObject*)basename); | ||
} else { | ||
basename = | ||
(PyArrayObject*)PyArray_FromAny((PyObject*)input, | ||
NULL, 0, 0, NPY_ARRAY_ALIGNED|NPY_ARRAY_ENSURECOPY, NULL); | ||
} | ||
if (!PyArray_IS_C_CONTIGUOUS(params->augment)) { | ||
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param input_broadcastable must be C-contiguous."); | ||
return 1; | ||
} | ||
|
||
/* shape_statements and strides_statements */ | ||
dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp)); | ||
strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp)); | ||
if (dimensions == NULL || strides == NULL) { | ||
PyErr_NoMemory(); | ||
free(dimensions); | ||
free(strides); | ||
return 1; | ||
}; | ||
|
||
for (npy_intp i = 0; i < nd_out; ++i) { | ||
if (new_order[i] != -1) { | ||
dimensions[i] = PyArray_DIMS(basename)[new_order[i]]; | ||
strides[i] = PyArray_DIMS(basename)[new_order[i]] == 1 ? | ||
0 : PyArray_STRIDES(basename)[new_order[i]]; | ||
} else { | ||
dimensions[i] = 1; | ||
strides[i] = 0; | ||
} | ||
} | ||
if (*res) | ||
Py_XDECREF(*res); | ||
|
||
/* set the strides of the broadcasted dimensions. | ||
* This algorithm is from numpy: PyArray_Newshape() in | ||
* cvs/numpy/numpy/core/src/multiarraymodule.c */ | ||
if (nd_out > 0) { | ||
if (strides[nd_out - 1] == 0) | ||
strides[nd_out - 1] = PyArray_DESCR(basename)->elsize; | ||
for (npy_intp i = nd_out - 2; i > -1; --i) { | ||
if (strides[i] == 0) | ||
strides[i] = strides[i + 1] * dimensions[i + 1]; | ||
} | ||
} | ||
if (params->inplace) { | ||
_input = input; | ||
Py_INCREF((PyObject *)_input); | ||
} else { | ||
_input = (PyArrayObject *)PyArray_FromAny( | ||
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY, | ||
NULL); | ||
} | ||
|
||
PyArray_Dims permute; | ||
|
||
/* close_bracket */ | ||
// create a new array. | ||
*res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions, | ||
PyArray_TYPE(basename), strides, | ||
PyArray_DATA(basename), PyArray_ITEMSIZE(basename), | ||
// borrow only the writable flag from the base | ||
// the NPY_OWNDATA flag will default to 0. | ||
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(basename)), | ||
NULL); | ||
|
||
if (*res == NULL) { | ||
free(dimensions); | ||
free(strides); | ||
return 1; | ||
if (!PyArray_IntpConverter((PyObject *)params->transposition, &permute)) { | ||
return 1; | ||
} | ||
|
||
/* | ||
res = res.transpose(self.transposition) | ||
*/ | ||
PyArrayObject *transposed_input = | ||
(PyArrayObject *)PyArray_Transpose(_input, &permute); | ||
|
||
PyDimMem_FREE(permute.ptr); | ||
|
||
npy_intp *res_shape = PyArray_DIMS(transposed_input); | ||
npy_intp *augment = (npy_intp *)PyArray_DATA(params->augment); | ||
npy_intp N_shuffle = PyArray_SIZE(params->shuffle); | ||
npy_intp N_augment = PyArray_SIZE(params->augment); | ||
npy_intp N = N_augment + N_shuffle; | ||
npy_intp *_reshape_shape = (npy_intp *)malloc(N * sizeof(npy_intp)); | ||
|
||
if (_reshape_shape == NULL) { | ||
PyErr_NoMemory(); | ||
free(_reshape_shape); | ||
return 1; | ||
} | ||
|
||
/* | ||
shape = list(res.shape[: len(self.shuffle)]) | ||
for augm in self.augment: | ||
shape.insert(augm, 1) | ||
*/ | ||
int aug_idx = 0; | ||
int res_idx = 0; | ||
for (npy_intp i = 0; i < N; i++) { | ||
if (aug_idx < N_augment && i == augment[aug_idx]) { | ||
_reshape_shape[i] = 1; | ||
aug_idx++; | ||
} else { | ||
_reshape_shape[i] = res_shape[res_idx]; | ||
res_idx++; | ||
} | ||
} | ||
|
||
PyArray_Dims reshape_shape = {.ptr = _reshape_shape, .len = N}; | ||
|
||
/* res = res.reshape(shape) */ | ||
*res = (PyArrayObject *)PyArray_Newshape(transposed_input, &reshape_shape, | ||
NPY_CORDER); | ||
|
||
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED | ||
PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL); | ||
/* Py_XDECREF(transposed_input); */ | ||
|
||
// we are making a view in both inplace and non-inplace cases | ||
PyArray_SetBaseObject(*res, (PyObject*)basename); | ||
PyDimMem_FREE(reshape_shape.ptr); | ||
|
||
free(strides); | ||
free(dimensions); | ||
if (!*res) { | ||
return 1; | ||
} | ||
|
||
return 0; | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.