Skip to content

Commit

Permalink
Use NumPy C API to perform DimShuffle steps in its C implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Dec 15, 2021
1 parent e1f6f13 commit 4d7c16b
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 162 deletions.
2 changes: 1 addition & 1 deletion aesara/gpuarray/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def perform(self, node, inp, out, params):

res = input

res = res.transpose(self.shuffle + self.drop)
res = res.transpose(self.transposition)

shape = list(res.shape[: len(self.shuffle)])
for augm in self.augment:
Expand Down
2 changes: 1 addition & 1 deletion aesara/link/jax/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ def reshape(x, shape):
def jax_funcify_DimShuffle(op, **kwargs):
def dimshuffle(x):

res = jnp.transpose(x, op.shuffle + op.drop)
res = jnp.transpose(x, op.transposition)

shape = list(res.shape[: len(op.shuffle)])

Expand Down
4 changes: 2 additions & 2 deletions aesara/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
@numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, **kwargs):
shuffle = tuple(op.shuffle)
drop = tuple(op.drop)
transposition = tuple(op.transposition)
augment = tuple(op.augment)
inplace = op.inplace

Expand Down Expand Up @@ -352,7 +352,7 @@ def populate_new_shape(i, j, new_shape, shuffle_shape):

@numba.njit
def dimshuffle_inner(x, shuffle):
res = np.transpose(x, shuffle + drop)
res = np.transpose(x, transposition)
shuffle_shape = res.shape[: len(shuffle)]

new_shape = create_zeros_tuple()
Expand Down
164 changes: 73 additions & 91 deletions aesara/tensor/c_code/dimshuffle.c
Original file line number Diff line number Diff line change
@@ -1,104 +1,86 @@
#section support_code_apply

int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject* input, PyArrayObject** res, PARAMS_TYPE* params) {
npy_bool* input_broadcastable;
npy_int64* new_order;
npy_intp nd_in;
npy_intp nd_out;
PyArrayObject* basename;
npy_intp* dimensions;
npy_intp* strides;

if (!PyArray_IS_C_CONTIGUOUS(params->input_broadcastable)) {
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param input_broadcastable must be C-contiguous.");
return 1;
}
if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) {
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous.");
return 1;
}
input_broadcastable = (npy_bool*) PyArray_DATA(params->input_broadcastable);
new_order = (npy_int64*) PyArray_DATA(params->_new_order);
nd_in = PyArray_SIZE(params->input_broadcastable);
nd_out = PyArray_SIZE(params->_new_order);

/* check_input_nd */
if (PyArray_NDIM(input) != nd_in) {
PyErr_SetString(PyExc_NotImplementedError, "input nd");
return 1;
}
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res,
PARAMS_TYPE *params) {

/* clear_output */
if (*res)
Py_XDECREF(*res);
// This points to either the original input or a copy we create below.
// Either way, this is what we should be working on/with.
PyArrayObject *_input;

/* get_base */
if (params->inplace) {
basename = input;
Py_INCREF((PyObject*)basename);
} else {
basename =
(PyArrayObject*)PyArray_FromAny((PyObject*)input,
NULL, 0, 0, NPY_ARRAY_ALIGNED|NPY_ARRAY_ENSURECOPY, NULL);
}
if (!PyArray_IS_C_CONTIGUOUS(params->augment)) {
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param input_broadcastable must be C-contiguous.");
return 1;
}

/* shape_statements and strides_statements */
dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
if (dimensions == NULL || strides == NULL) {
PyErr_NoMemory();
free(dimensions);
free(strides);
return 1;
};

for (npy_intp i = 0; i < nd_out; ++i) {
if (new_order[i] != -1) {
dimensions[i] = PyArray_DIMS(basename)[new_order[i]];
strides[i] = PyArray_DIMS(basename)[new_order[i]] == 1 ?
0 : PyArray_STRIDES(basename)[new_order[i]];
} else {
dimensions[i] = 1;
strides[i] = 0;
}
}
if (*res)
Py_XDECREF(*res);

/* set the strides of the broadcasted dimensions.
* This algorithm is from numpy: PyArray_Newshape() in
* cvs/numpy/numpy/core/src/multiarraymodule.c */
if (nd_out > 0) {
if (strides[nd_out - 1] == 0)
strides[nd_out - 1] = PyArray_DESCR(basename)->elsize;
for (npy_intp i = nd_out - 2; i > -1; --i) {
if (strides[i] == 0)
strides[i] = strides[i + 1] * dimensions[i + 1];
}
}
if (params->inplace) {
_input = input;
Py_INCREF((PyObject *)_input);
} else {
_input = (PyArrayObject *)PyArray_FromAny(
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY,
NULL);
}

PyArray_Dims permute;

/* close_bracket */
// create a new array.
*res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions,
PyArray_TYPE(basename), strides,
PyArray_DATA(basename), PyArray_ITEMSIZE(basename),
// borrow only the writable flag from the base
// the NPY_OWNDATA flag will default to 0.
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(basename)),
NULL);

if (*res == NULL) {
free(dimensions);
free(strides);
return 1;
if (!PyArray_IntpConverter((PyObject *)params->transposition, &permute)) {
return 1;
}

/*
res = res.transpose(self.transposition)
*/
PyArrayObject *transposed_input =
(PyArrayObject *)PyArray_Transpose(_input, &permute);

PyDimMem_FREE(permute.ptr);

npy_intp *res_shape = PyArray_DIMS(transposed_input);
npy_intp *augment = (npy_intp *)PyArray_DATA(params->augment);
npy_intp N_shuffle = PyArray_SIZE(params->shuffle);
npy_intp N_augment = PyArray_SIZE(params->augment);
npy_intp N = N_augment + N_shuffle;
npy_intp *_reshape_shape = (npy_intp *)malloc(N * sizeof(npy_intp));

if (_reshape_shape == NULL) {
PyErr_NoMemory();
free(_reshape_shape);
return 1;
}

/*
shape = list(res.shape[: len(self.shuffle)])
for augm in self.augment:
shape.insert(augm, 1)
*/
int aug_idx = 0;
int res_idx = 0;
for (npy_intp i = 0; i < N; i++) {
if (aug_idx < N_augment && i == augment[aug_idx]) {
_reshape_shape[i] = 1;
aug_idx++;
} else {
_reshape_shape[i] = res_shape[res_idx];
res_idx++;
}
}

PyArray_Dims reshape_shape = {.ptr = _reshape_shape, .len = N};

/* res = res.reshape(shape) */
*res = (PyArrayObject *)PyArray_Newshape(transposed_input, &reshape_shape,
NPY_CORDER);

// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL);
/* Py_XDECREF(transposed_input); */

// we are making a view in both inplace and non-inplace cases
PyArray_SetBaseObject(*res, (PyObject*)basename);
PyDimMem_FREE(reshape_shape.ptr);

free(strides);
free(dimensions);
if (!*res) {
return 1;
}

return 0;
return 0;
}
79 changes: 27 additions & 52 deletions aesara/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,47 +119,27 @@ class DimShuffle(ExternalCOp):

@property
def params_type(self):
# We can't directly create `params_type` as class attribute
# because of importation issues related to TensorType.
return ParamsType(
input_broadcastable=TensorType(dtype="bool", broadcastable=(False,)),
_new_order=lvector,
transposition=TensorType(dtype="uint32", broadcastable=(False,)),
shuffle=lvector,
augment=lvector,
transposition=lvector,
inplace=scalar_bool,
)

@property
def _new_order(self):
# Param for C code.
# self.new_order may contain 'x', which is not a valid integer value.
# We replace it with -1.
return [(-1 if x == "x" else x) for x in self.new_order]

@property
def transposition(self):
return self.shuffle + self.drop

def __init__(self, input_broadcastable, new_order, inplace=True):
def __init__(self, input_broadcastable, new_order):
super().__init__([self.c_func_file], self.c_func_name)

self.input_broadcastable = tuple(input_broadcastable)
self.new_order = tuple(new_order)
if inplace is True:
self.inplace = inplace
else:
raise ValueError(
"DimShuffle is inplace by default and hence the inplace for DimShuffle must be true"
)

self.inplace = True

for i, j in enumerate(new_order):
if j != "x":
# There is a bug in numpy that results in
# isinstance(x, integer_types) returning False for
# numpy integers. See
# <http://projects.scipy.org/numpy/ticket/2235>.
if not isinstance(j, (int, np.integer)):
raise TypeError(
"DimShuffle indices must be python ints. "
f"Got: '{j}' of type '{type(j)}'."
"DimShuffle indices must be Python ints; got "
f"{j} of type {type(j)}."
)
if j >= len(input_broadcastable):
raise ValueError(
Expand All @@ -169,31 +149,30 @@ def __init__(self, input_broadcastable, new_order, inplace=True):
if j in new_order[(i + 1) :]:
raise ValueError(
"The same input dimension may not appear "
"twice in the list of output dimensions",
new_order,
f"twice in the list of output dimensions: {new_order}"
)

# list of dimensions of the input to drop
self.drop = []
# List of input dimensions to drop
drop = []
for i, b in enumerate(input_broadcastable):
if i not in new_order:
# we want to drop this dimension because it's not a value in
# new_order
if b == 1: # 1 aka True
self.drop.append(i)
# We want to drop this dimension because it's not a value in
# `new_order`
if b == 1:
drop.append(i)
else:
# we cannot drop non-broadcastable dimensions
# We cannot drop non-broadcastable dimensions
raise ValueError(
"You cannot drop a non-broadcastable dimension:",
f" {input_broadcastable}, {new_order}",
"Cannot drop a non-broadcastable dimension: "
f"{input_broadcastable}, {new_order}"
)

# this is the list of the original dimensions that we keep
# This is the list of the original dimensions that we keep
self.shuffle = [x for x in new_order if x != "x"]

# list of dimensions of the output that are broadcastable and were not
self.transposition = self.shuffle + drop
# List of dimensions of the output that are broadcastable and were not
# in the original input
self.augment = [i for i, x in enumerate(new_order) if x == "x"]
self.augment = sorted([i for i, x in enumerate(new_order) if x == "x"])

if self.inplace:
self.view_map = {0: [0]}
Expand Down Expand Up @@ -241,27 +220,23 @@ def __str__(self):
return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order)

def perform(self, node, inp, out, params):
(input,) = inp
(res,) = inp
(storage,) = out
# drop
res = input

if type(res) != np.ndarray and type(res) != np.memmap:
raise TypeError(res)

# transpose
res = res.transpose(self.shuffle + self.drop)
res = res.transpose(self.transposition)

# augment
shape = list(res.shape[: len(self.shuffle)])
for augm in self.augment:
shape.insert(augm, 1)
res = res.reshape(shape)

# copy (if not inplace)
if not self.inplace:
res = np.copy(res)

storage[0] = np.asarray(res) # asarray puts scalars back into array
storage[0] = np.asarray(res)

def infer_shape(self, fgraph, node, shapes):
(ishp,) = shapes
Expand Down
2 changes: 1 addition & 1 deletion aesara/tensor/inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,4 +399,4 @@ def conj_inplace(a):
def transpose_inplace(x, **kwargs):
"Perform a transpose on a tensor without copying the underlying storage"
dims = list(range(x.ndim - 1, -1, -1))
return DimShuffle(x.broadcastable, dims, inplace=True)(x)
return DimShuffle(x.broadcastable, dims)(x)
2 changes: 1 addition & 1 deletion tests/link/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def test_jax_Dimshuffle():
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])

a_aet = tensor(dtype=config.floatX, broadcastable=[False, True])
x = aet_elemwise.DimShuffle([False, True], (0,), inplace=True)(a_aet)
x = aet_elemwise.DimShuffle([False, True], (0,))(a_aet)
x_fg = FunctionGraph([a_aet], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])

Expand Down
Loading

0 comments on commit 4d7c16b

Please sign in to comment.