Skip to content

Commit

Permalink
fix axis index
Browse files Browse the repository at this point in the history
  • Loading branch information
v923z committed Oct 9, 2024
1 parent d2544f0 commit ccfb483
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions code/numpy/create.c
Original file line number Diff line number Diff line change
Expand Up @@ -942,10 +942,10 @@ mp_obj_t create_take(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
#if ULAB_MAX_DIMS > 1
} else {
// move the axis shape/stride to the leftmost position:
SWAP(size_t, a->shape[0], a->shape[axis]);
SWAP(size_t, out->shape[0], out->shape[axis]);
SWAP(int32_t, a->strides[0], a->strides[axis]);
SWAP(int32_t, out->strides[0], out->strides[axis]);
SWAP(size_t, a->shape[0], a->shape[axis_index]);
SWAP(size_t, out->shape[0], out->shape[axis_index]);
SWAP(int32_t, a->strides[0], a->strides[axis_index]);
SWAP(int32_t, out->strides[0], out->strides[axis_index]);

for(size_t x = 0; x < indices_len; x++) {
uint8_t *a_array = (uint8_t *)a->array;
Expand Down Expand Up @@ -991,10 +991,10 @@ mp_obj_t create_take(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
}

// revert back to the original order
SWAP(size_t, a->shape[0], a->shape[axis]);
SWAP(size_t, out->shape[0], out->shape[axis]);
SWAP(int32_t, a->strides[0], a->strides[axis]);
SWAP(int32_t, out->strides[0], out->strides[axis]);
SWAP(size_t, a->shape[0], a->shape[axis_index]);
SWAP(size_t, out->shape[0], out->shape[axis_index]);
SWAP(int32_t, a->strides[0], a->strides[axis_index]);
SWAP(int32_t, out->strides[0], out->strides[axis_index]);
}
#endif /* ULAB_MAX_DIMS > 1 */
m_del(size_t, indices, indices_len);
Expand Down

0 comments on commit ccfb483

Please sign in to comment.