Skip to content

Commit

Permalink
Take (#688)
Browse files Browse the repository at this point in the history
* add numpy.take
  • Loading branch information
v923z authored Oct 9, 2024
1 parent c0b3262 commit 2b74236
Show file tree
Hide file tree
Showing 14 changed files with 544 additions and 19 deletions.
231 changes: 230 additions & 1 deletion code/numpy/create.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* The MIT License (MIT)
*
* Copyright (c) 2020 Jeff Epler for Adafruit Industries
* 2019-2021 Zoltán Vörös
* 2019-2024 Zoltán Vörös
* 2020 Taku Fukada
*/

Expand Down Expand Up @@ -776,6 +776,235 @@ mp_obj_t create_ones(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
MP_DEFINE_CONST_FUN_OBJ_KW(create_ones_obj, 0, create_ones);
#endif

#if ULAB_NUMPY_HAS_TAKE
//| def take(
//| a: ulab.numpy.ndarray,
//| indices: _ArrayLike,
//| *,
//| axis: Optional[int] = None,
//| out: Optional[ulab.numpy.ndarray] = None,
//| mode: Optional[str] = None) -> ulab.numpy.ndarray:
//| """
//| .. param: a
//| The source array.
//| .. param: indices
//| The indices of the values to extract.
//| .. param: axis
//| The axis over which to select values. By default, the flattened input array is used.
//| .. param: out
//| If provided, the result will be placed in this array. It should be of the appropriate shape and dtype.
//| .. param: mode
//| Specifies how out-of-bounds indices will behave.
//| - `raise`: raise an error (default)
//| - `wrap`: wrap around
//| - `clip`: clip to the range
//| `clip` mode means that all indices that are too large are replaced by the
//| index that addresses the last element along that axis. Note that this disables
//| indexing with negative numbers.
//|
//| Return the array holding the selected values: `out`, if it was supplied, otherwise a new array."""
//| ...
//|

// Policy applied by take() to indices that fall outside the selected axis.
enum CREATE_TAKE_MODE {
    CREATE_TAKE_RAISE = 0, // reject the index with an error (the default)
    CREATE_TAKE_WRAP = 1,  // fold the index back into the valid range
    CREATE_TAKE_CLIP = 2,  // replace too-large indices by the last valid one
};

// take(a, indices, *, axis=None, out=None, mode=None)
//
// Copies the elements of `a` selected by `indices` into a result array.
//
//   a        the source ndarray (anything else raises TypeError)
//   indices  an iterable of integers that also supports len()
//   axis     None: `a` is addressed as if it were flattened in row-major order,
//            and the result is one-dimensional; otherwise the axis along which
//            elements are selected (negative values count from the last axis)
//   out      optional, dense ndarray receiving the result; its dtype, ndim and
//            shape are validated by ulab_tools_inspect_out
//   mode     "raise" (default), "wrap" or "clip"; a non-string value is ignored
//            and leaves the default in place
//
// Returns `out`, if it was supplied, otherwise a newly created dense array.
// Raises ValueError, if the axis or an index is out of range, if a negative
// index is used in clip mode, or if the mode string is not recognised.
mp_obj_t create_take(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
    static const mp_arg_t allowed_args[] = {
        { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_obj = MP_OBJ_NULL } },
        { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_obj = MP_OBJ_NULL } },
        { MP_QSTR_axis, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
        { MP_QSTR_out, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
        { MP_QSTR_mode, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
    };

    mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
    mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);

    if(!mp_obj_is_type(args[0].u_obj, &ulab_ndarray_type)) {
        mp_raise_TypeError(MP_ERROR_TEXT("input is not an array"));
    }

    ndarray_obj_t *a = MP_OBJ_TO_PTR(args[0].u_obj);
    int8_t axis_index = 0;
    int32_t axis_len;
    uint8_t mode = CREATE_TAKE_RAISE;
    uint8_t ndim;

    // axis keyword argument
    if(args[2].u_obj == mp_const_none) {
        // work with the flattened array
        axis_len = (int32_t)a->len;
        ndim = 1;
    } else { // i.e., axis is an integer
        // TODO: this pops up at quite a few places, write it as a function
        // keep the full integer width until the range check is done, so that
        // values like 256 cannot alias a valid axis after narrowing
        mp_int_t axis = mp_obj_get_int(args[2].u_obj);
        ndim = a->ndim;
        if(axis < 0) {
            axis += a->ndim;
        }
        if((axis < 0) || (axis > a->ndim - 1)) {
            mp_raise_ValueError(MP_ERROR_TEXT("index out of range"));
        }
        // the used dimensions occupy the trailing slots of shape/strides
        axis_index = (int8_t)(ULAB_MAX_DIMS - a->ndim + axis);
        axis_len = (int32_t)a->shape[axis_index];
    }

    // mode keyword argument
    if(mp_obj_is_str(args[4].u_obj)) {
        size_t mode_len;
        const char *mode_str = mp_obj_str_get_data(args[4].u_obj, &mode_len);
        // compare the length first: memcmp alone would read past the end of a
        // shorter string, and would accept anything that merely starts with a keyword
        if((mode_len == 5) && (memcmp(mode_str, "raise", 5) == 0)) {
            mode = CREATE_TAKE_RAISE;
        } else if((mode_len == 4) && (memcmp(mode_str, "wrap", 4) == 0)) {
            mode = CREATE_TAKE_WRAP;
        } else if((mode_len == 4) && (memcmp(mode_str, "clip", 4) == 0)) {
            mode = CREATE_TAKE_CLIP;
        } else {
            mp_raise_ValueError(MP_ERROR_TEXT("mode should be raise, wrap or clip"));
        }
    }

    // mp_obj_len raises a TypeError for objects without a length, whereas
    // mp_obj_len_maybe would hand back MP_OBJ_NULL for them
    size_t indices_len = (size_t)mp_obj_get_int(mp_obj_len(args[1].u_obj));

    // Nothing can be taken from an empty axis; bail out here, because the wrap
    // mode would divide by zero, and the clip mode would produce the index -1.
    if((axis_len == 0) && (indices_len > 0)) {
        mp_raise_ValueError(MP_ERROR_TEXT("index out of range"));
    }

    // zero-initialised, so that every slot holds a valid index, even if the
    // iterable yields fewer items than its reported length
    size_t *indices = m_new0(size_t, indices_len);

    mp_obj_iter_buf_t buf;
    mp_obj_t item, iterable = mp_getiter(args[1].u_obj, &buf);

    size_t z = 0;
    // never store more items than the buffer was sized for
    while((z < indices_len) && ((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION)) {
        mp_int_t index = mp_obj_get_int(item);
        if(mode == CREATE_TAKE_RAISE) {
            if(index < 0) {
                index += axis_len;
            }
            if((index < 0) || (index > axis_len - 1)) {
                m_del(size_t, indices, indices_len);
                mp_raise_ValueError(MP_ERROR_TEXT("index out of range"));
            }
        } else if(mode == CREATE_TAKE_WRAP) {
            // in C, the result of % takes the sign of the dividend, so negative
            // indices have to be shifted back into [0, axis_len)
            index %= axis_len;
            if(index < 0) {
                index += axis_len;
            }
        } else { // mode == CREATE_TAKE_CLIP
            if(index < 0) {
                m_del(size_t, indices, indices_len);
                mp_raise_ValueError(MP_ERROR_TEXT("index must not be negative"));
            }
            if(index > axis_len - 1) {
                index = axis_len - 1;
            }
        }
        indices[z++] = (size_t)index;
    }

    // shape of the result: that of the input, with the length of the selected
    // axis replaced by the number of indices
    size_t *shape = m_new0(size_t, ULAB_MAX_DIMS);
    if(args[2].u_obj == mp_const_none) { // flattened array
        shape[ULAB_MAX_DIMS - 1] = indices_len;
    } else {
        for(uint8_t i = 0; i < ULAB_MAX_DIMS; i++) {
            shape[i] = a->shape[i];
            if(i == axis_index) {
                shape[i] = indices_len;
            }
        }
    }

    ndarray_obj_t *out = NULL;
    if(args[3].u_obj == mp_const_none) {
        // no output was supplied
        out = ndarray_new_dense_ndarray(ndim, shape, a->dtype);
    } else {
        // TODO: deal with last argument being false!
        out = ulab_tools_inspect_out(args[3].u_obj, a->dtype, ndim, shape, true);
    }

    #if ULAB_MAX_DIMS > 1 // we can save the hassle, if there is only one possible dimension
    if((args[2].u_obj == mp_const_none) || (a->ndim == 1)) { // flattened array
    #endif
        uint8_t *out_array = (uint8_t *)out->array;
        for(size_t x = 0; x < indices_len; x++) {
            uint8_t *a_array = (uint8_t *)a->array;
            // Unravel the flat, row-major index: the last axis varies fastest,
            // hence, its coordinate is the remainder modulo its length. Going
            // through the strides keeps this correct for non-dense inputs, too.
            // No extent can be zero here: an empty input was rejected above.
            size_t remainder = indices[x];
            for(uint8_t q = ULAB_MAX_DIMS; q > ULAB_MAX_DIMS - a->ndim; q--) {
                size_t extent = a->shape[q - 1];
                a_array += (mp_int_t)(remainder % extent) * a->strides[q - 1];
                remainder /= extent;
            }
            memcpy(out_array, a_array, a->itemsize);
            // out is dense, its elements are simply consecutive
            out_array += a->itemsize;
        }
    #if ULAB_MAX_DIMS > 1
    } else {
        // Local copies of the extents/strides with the selected axis exchanged
        // with position 0. The array headers themselves are left untouched, so
        // the walk stays within bounds, even if out is the same object as a.
        // Apart from position 0, the extents of a and out are identical.
        size_t extent[ULAB_MAX_DIMS];
        int32_t a_strides[ULAB_MAX_DIMS];
        int32_t out_strides[ULAB_MAX_DIMS];
        for(uint8_t i = 0; i < ULAB_MAX_DIMS; i++) {
            extent[i] = a->shape[i];
            a_strides[i] = a->strides[i];
            out_strides[i] = out->strides[i];
        }
        extent[0] = a->shape[axis_index];
        extent[axis_index] = a->shape[0];
        a_strides[0] = a->strides[axis_index];
        a_strides[axis_index] = a->strides[0];
        out_strides[0] = out->strides[axis_index];
        out_strides[axis_index] = out->strides[0];

        // the do-while loops below execute at least once, so they must not be
        // entered at all, if the input holds no elements
        const size_t n_copy = (a->len == 0) ? 0 : indices_len;

        for(size_t x = 0; x < n_copy; x++) {
            uint8_t *a_array = (uint8_t *)a->array + (mp_int_t)indices[x] * a_strides[0];
            uint8_t *out_array = (uint8_t *)out->array + (mp_int_t)x * out_strides[0];

            #if ULAB_MAX_DIMS > 3
            size_t j = 0;
            do {
            #endif
                #if ULAB_MAX_DIMS > 2
                size_t k = 0;
                do {
                #endif
                    size_t l = 0;
                    do {
                        memcpy(out_array, a_array, a->itemsize);
                        out_array += out_strides[ULAB_MAX_DIMS - 1];
                        a_array += a_strides[ULAB_MAX_DIMS - 1];
                        l++;
                    } while(l < extent[ULAB_MAX_DIMS - 1]);
                #if ULAB_MAX_DIMS > 2
                    // rewind the innermost axis, and step along the next one
                    out_array -= (mp_int_t)extent[ULAB_MAX_DIMS - 1] * out_strides[ULAB_MAX_DIMS - 1];
                    out_array += out_strides[ULAB_MAX_DIMS - 2];
                    a_array -= (mp_int_t)extent[ULAB_MAX_DIMS - 1] * a_strides[ULAB_MAX_DIMS - 1];
                    a_array += a_strides[ULAB_MAX_DIMS - 2];
                    k++;
                } while(k < extent[ULAB_MAX_DIMS - 2]);
                #endif
            #if ULAB_MAX_DIMS > 3
                out_array -= (mp_int_t)extent[ULAB_MAX_DIMS - 2] * out_strides[ULAB_MAX_DIMS - 2];
                out_array += out_strides[ULAB_MAX_DIMS - 3];
                a_array -= (mp_int_t)extent[ULAB_MAX_DIMS - 2] * a_strides[ULAB_MAX_DIMS - 2];
                a_array += a_strides[ULAB_MAX_DIMS - 3];
                j++;
            } while(j < extent[ULAB_MAX_DIMS - 3]);
            #endif
        }
    }
    #endif /* ULAB_MAX_DIMS > 1 */
    m_del(size_t, indices, indices_len);
    return MP_OBJ_FROM_PTR(out);
}

MP_DEFINE_CONST_FUN_OBJ_KW(create_take_obj, 2, create_take);
#endif /* ULAB_NUMPY_HAS_TAKE */

#if ULAB_NUMPY_HAS_ZEROS
//| def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: _DType = ulab.numpy.float) -> ulab.numpy.ndarray:
//| """
Expand Down
5 changes: 5 additions & 0 deletions code/numpy/create.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ mp_obj_t create_ones(size_t , const mp_obj_t *, mp_map_t *);
MP_DECLARE_CONST_FUN_OBJ_KW(create_ones_obj);
#endif

#if ULAB_NUMPY_HAS_TAKE
mp_obj_t create_take(size_t , const mp_obj_t *, mp_map_t *);
MP_DECLARE_CONST_FUN_OBJ_KW(create_take_obj);
#endif

#if ULAB_NUMPY_HAS_ZEROS
mp_obj_t create_zeros(size_t , const mp_obj_t *, mp_map_t *);
MP_DECLARE_CONST_FUN_OBJ_KW(create_zeros_obj);
Expand Down
3 changes: 3 additions & 0 deletions code/numpy/numpy.c
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@ static const mp_rom_map_elem_t ulab_numpy_globals_table[] = {
#if ULAB_NUMPY_HAS_SUM
{ MP_ROM_QSTR(MP_QSTR_sum), MP_ROM_PTR(&numerical_sum_obj) },
#endif
#if ULAB_NUMPY_HAS_TAKE
{ MP_ROM_QSTR(MP_QSTR_take), MP_ROM_PTR(&create_take_obj) },
#endif
// functions of the poly sub-module
#if ULAB_NUMPY_HAS_POLYFIT
{ MP_ROM_QSTR(MP_QSTR_polyfit), MP_ROM_PTR(&poly_polyfit_obj) },
Expand Down
2 changes: 1 addition & 1 deletion code/ulab.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"

#define ULAB_VERSION 6.5.5
#define ULAB_VERSION 6.6.0
#define xstr(s) str(s)
#define str(s) #s

Expand Down
4 changes: 4 additions & 0 deletions code/ulab.h
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,10 @@
#define ULAB_NUMPY_HAS_SUM (1)
#endif

#ifndef ULAB_NUMPY_HAS_TAKE
#define ULAB_NUMPY_HAS_TAKE (1)
#endif

#ifndef ULAB_NUMPY_HAS_TRACE
#define ULAB_NUMPY_HAS_TRACE (1)
#endif
Expand Down
28 changes: 28 additions & 0 deletions code/ulab_tools.c
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,31 @@ bool ulab_tools_mp_obj_is_scalar(mp_obj_t obj) {
}
#endif
}

// Checks that `out` can serve as the output array of an operation.
//
//   out         the object passed by the caller as the output
//   dtype       the dtype the output has to have
//   ndim        the number of dimensions the output has to have
//   shape       the required shape; all ULAB_MAX_DIMS entries are compared
//   dense_only  if true, the output is additionally required to be dense
//
// Returns `out` as an ndarray pointer, if all checks pass. Raises TypeError,
// if `out` is not an ndarray, and ValueError, if any of the other checks fail.
ndarray_obj_t *ulab_tools_inspect_out(mp_obj_t out, uint8_t dtype, uint8_t ndim, size_t *shape, bool dense_only) {
    if(!mp_obj_is_type(out, &ulab_ndarray_type)) {
        mp_raise_TypeError(MP_ERROR_TEXT("out has wrong type"));
    }

    ndarray_obj_t *target = MP_OBJ_TO_PTR(out);

    if(dtype != target->dtype) {
        mp_raise_ValueError(MP_ERROR_TEXT("out array has wrong dtype"));
    }
    if(ndim != target->ndim) {
        mp_raise_ValueError(MP_ERROR_TEXT("out array has wrong dimension"));
    }

    // a single differing entry is sufficient for rejecting the array
    bool same_shape = true;
    for(uint8_t dim = 0; same_shape && (dim < ULAB_MAX_DIMS); dim++) {
        same_shape = (target->shape[dim] == shape[dim]);
    }
    if(!same_shape) {
        mp_raise_ValueError(MP_ERROR_TEXT("out array has wrong shape"));
    }

    // the density is inspected only on request
    if(dense_only && !ndarray_is_dense(target)) {
        mp_raise_ValueError(MP_ERROR_TEXT("output array must be contiguous"));
    }

    return target;
}
5 changes: 2 additions & 3 deletions code/ulab_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ void ulab_rescale_float_strides(int32_t *);

bool ulab_tools_mp_obj_is_scalar(mp_obj_t );

#if ULAB_NUMPY_HAS_RANDOM_MODULE
ndarray_obj_t *ulab_tools_create_out(mp_obj_tuple_t , mp_obj_t , uint8_t , bool );
#endif
ndarray_obj_t *ulab_tools_inspect_out(mp_obj_t , uint8_t , uint8_t , size_t *, bool );

#endif
3 changes: 1 addition & 2 deletions docs/manual/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
author = 'Zoltán Vörös'

# The full version, including alpha/beta/rc tags
release = '6.5.5'

release = '6.6.0'

# -- General configuration ---------------------------------------------------

Expand Down
Loading

0 comments on commit 2b74236

Please sign in to comment.