Skip to content

Commit

Permalink
Type promotion for indices arrays and casting vals in integer indexing (#1647)
Browse files Browse the repository at this point in the history

* Tweaks to advanced integer indexing

Setting items in an array now casts the right-hand side to the array data type when the data types differ

Setting and getting from an empty axis with non-empty indices now throws `IndexError`

* Integer advanced indexing now promotes indices arrays to a common integer data type

* `put` now casts `vals` to the data type of `x` when their data types differ

Fixes `take` and `put` being used on empty axes with non-empty indices

Also adds a note to `put` about race conditions for non-unique indices

* Adds tests for indexing array casting for indices and values

* Fixes range when checking for empty axes in _take/_put_multi_index

Also corrects error raised in _put_multi_index when attempting to put into indices along an empty axis

* Changes per PR review
  • Loading branch information
ndgrigorian authored Apr 21, 2024
1 parent f5c6610 commit 7757857
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 60 deletions.
124 changes: 85 additions & 39 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,44 +763,66 @@ def _nonzero_impl(ary):

def _take_multi_index(ary, inds, p):
if not isinstance(ary, dpt.usm_ndarray):
raise TypeError
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
)
ary_nd = ary.ndim
p = normalize_axis_index(operator.index(p), ary_nd)
queues_ = [
ary.sycl_queue,
]
usm_types_ = [
ary.usm_type,
]
if not isinstance(inds, list) and not isinstance(inds, tuple):
if not isinstance(inds, (list, tuple)):
inds = (inds,)
all_integers = True
for ind in inds:
if not isinstance(ind, dpt.usm_ndarray):
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
queues_.append(ind.sycl_queue)
usm_types_.append(ind.usm_type)
if all_integers:
all_integers = ind.dtype.kind in "ui"
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) type"
)
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError("")
if not all_integers:
raise IndexError(
"arrays used as indices must be of integer (or boolean) type"
raise dpctl.utils.ExecutionPlacementError(
"Can not automatically determine where to allocate the "
"result or performance execution. "
"Use `usm_ndarray.to_device` method to migrate data to "
"be associated with the same queue."
)
if len(inds) > 1:
ind_dt = dpt.result_type(*inds)
# ind arrays have been checked to be of integer dtype
if ind_dt.kind not in "ui":
raise ValueError(
"cannot safely promote indices to an integer data type"
)
inds = tuple(
map(
lambda ind: ind
if ind.dtype == ind_dt
else dpt.astype(ind, ind_dt),
inds,
)
)
inds = dpt.broadcast_arrays(*inds)
ary_ndim = ary.ndim
p = normalize_axis_index(operator.index(p), ary_ndim)

res_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
ind0 = inds[0]
ary_sh = ary.shape
p_end = p + len(inds)
if 0 in ary_sh[p:p_end] and ind0.size != 0:
raise IndexError("cannot take non-empty indices from an empty axis")
res_shape = ary_sh[:p] + ind0.shape + ary_sh[p_end:]
res = dpt.empty(
res_shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
)

hev, _ = ti._take(
src=ary, ind=inds, dst=res, axis_start=p, mode=0, sycl_queue=exec_q
)
hev.wait()

return res


Expand Down Expand Up @@ -864,6 +886,12 @@ def _place_impl(ary, ary_mask, vals, axis=0):


def _put_multi_index(ary, inds, p, vals):
if not isinstance(ary, dpt.usm_ndarray):
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
)
ary_nd = ary.ndim
p = normalize_axis_index(operator.index(p), ary_nd)
if isinstance(vals, dpt.usm_ndarray):
queues_ = [ary.sycl_queue, vals.sycl_queue]
usm_types_ = [ary.usm_type, vals.usm_type]
Expand All @@ -874,46 +902,64 @@ def _put_multi_index(ary, inds, p, vals):
usm_types_ = [
ary.usm_type,
]
if not isinstance(inds, list) and not isinstance(inds, tuple):
if not isinstance(inds, (list, tuple)):
inds = (inds,)
all_integers = True
for ind in inds:
if not isinstance(ind, dpt.usm_ndarray):
raise TypeError
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
queues_.append(ind.sycl_queue)
usm_types_.append(ind.usm_type)
if all_integers:
all_integers = ind.dtype.kind in "ui"
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) type"
)
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is not None:
if not isinstance(vals, dpt.usm_ndarray):
vals = dpt.asarray(
vals, dtype=ary.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
)
else:
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"Can not automatically determine where to allocate the "
"result or performance execution. "
"Use `usm_ndarray.to_device` method to migrate data to "
"be associated with the same queue."
)
if not all_integers:
raise IndexError(
"arrays used as indices must be of integer (or boolean) type"
)
if len(inds) > 1:
ind_dt = dpt.result_type(*inds)
# ind arrays have been checked to be of integer dtype
if ind_dt.kind not in "ui":
raise ValueError(
"cannot safely promote indices to an integer data type"
)
inds = tuple(
map(
lambda ind: ind
if ind.dtype == ind_dt
else dpt.astype(ind, ind_dt),
inds,
)
)
inds = dpt.broadcast_arrays(*inds)
ary_ndim = ary.ndim

p = normalize_axis_index(operator.index(p), ary_ndim)
vals_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]

vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
if not isinstance(vals, dpt.usm_ndarray):
vals = dpt.asarray(
vals, dtype=ary.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
ind0 = inds[0]
ary_sh = ary.shape
p_end = p + len(inds)
if 0 in ary_sh[p:p_end] and ind0.size != 0:
raise IndexError(
"cannot put into non-empty indices along an empty axis"
)

vals = dpt.broadcast_to(vals, vals_shape)

expected_vals_shape = ary_sh[:p] + ind0.shape + ary_sh[p_end:]
if vals.dtype == ary.dtype:
rhs = vals
else:
rhs = dpt.astype(vals, ary.dtype)
rhs = dpt.broadcast_to(rhs, expected_vals_shape)
hev, _ = ti._put(
dst=ary, ind=inds, val=vals, axis_start=p, mode=0, sycl_queue=exec_q
dst=ary, ind=inds, val=rhs, axis_start=p, mode=0, sycl_queue=exec_q
)
hev.wait()

return
61 changes: 40 additions & 21 deletions dpctl/tensor/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import operator

import numpy as np
from numpy.core.numeric import normalize_axis_index

import dpctl
Expand Down Expand Up @@ -47,15 +46,15 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
indices (usm_ndarray):
One-dimensional array of indices.
axis:
The axis over which the values will be selected.
If x is one-dimensional, this argument is optional.
Default: `None`.
The axis along which the values will be selected.
If ``x`` is one-dimensional, this argument is optional.
Default: ``None``.
mode:
How out-of-bounds indices will be handled.
"wrap" - clamps indices to (-n <= i < n), then wraps
``"wrap"`` - clamps indices to (-n <= i < n), then wraps
negative indices.
"clip" - clips indices to (0 <= i < n)
Default: `"wrap"`.
``"clip"`` - clips indices to (0 <= i < n)
Default: ``"wrap"``.
Returns:
usm_ndarray:
Expand All @@ -73,7 +72,7 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
type(indices)
)
)
if not np.issubdtype(indices.dtype, np.integer):
if indices.dtype.kind not in "ui":
raise IndexError(
"`indices` expected integer data type, got `{}`".format(
indices.dtype
Expand Down Expand Up @@ -104,6 +103,9 @@ def take(x, indices, /, *, axis=None, mode="wrap"):

if x_ndim > 0:
axis = normalize_axis_index(operator.index(axis), x_ndim)
x_sh = x.shape
if x_sh[axis] == 0 and indices.size != 0:
raise IndexError("cannot take non-empty indices from an empty axis")
res_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
else:
if axis != 0:
Expand All @@ -130,19 +132,26 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
The array the values will be put into.
indices (usm_ndarray)
One-dimensional array of indices.
Note that if indices are not unique, a race
condition will result, and the value written to
``x`` will not be deterministic.
:py:func:`dpctl.tensor.unique` can be used to
guarantee unique elements in ``indices``.
vals:
Array of values to be put into `x`.
Must be broadcastable to the shape of `indices`.
Array of values to be put into ``x``.
Must be broadcastable to the result shape
``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
axis:
The axis over which the values will be placed.
If x is one-dimensional, this argument is optional.
Default: `None`.
The axis along which the values will be placed.
If ``x`` is one-dimensional, this argument is optional.
Default: ``None``.
mode:
How out-of-bounds indices will be handled.
"wrap" - clamps indices to (-n <= i < n), then wraps
``"wrap"`` - clamps indices to (-n <= i < n), then wraps
negative indices.
"clip" - clips indices to (0 <= i < n)
Default: `"wrap"`.
``"clip"`` - clips indices to (0 <= i < n)
Default: ``"wrap"``.
"""
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(
Expand All @@ -168,7 +177,7 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
raise ValueError(
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
)
if not np.issubdtype(indices.dtype, np.integer):
if indices.dtype.kind not in "ui":
raise IndexError(
"`indices` expected integer data type, got `{}`".format(
indices.dtype
Expand All @@ -195,7 +204,9 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):

if x_ndim > 0:
axis = normalize_axis_index(operator.index(axis), x_ndim)

x_sh = x.shape
if x_sh[axis] == 0 and indices.size != 0:
raise IndexError("cannot take non-empty indices from an empty axis")
val_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
else:
if axis != 0:
Expand All @@ -206,10 +217,18 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
vals = dpt.asarray(
vals, dtype=x.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
)
# choose to throw here for consistency with `place`
if vals.size == 0:
raise ValueError(
"cannot put into non-empty indices along an empty axis"
)
if vals.dtype == x.dtype:
rhs = vals
else:
rhs = dpt.astype(vals, x.dtype)
rhs = dpt.broadcast_to(rhs, val_shape)

vals = dpt.broadcast_to(vals, val_shape)

hev, _ = ti._put(x, (indices,), vals, axis, mode, sycl_queue=exec_q)
hev, _ = ti._put(x, (indices,), rhs, axis, mode, sycl_queue=exec_q)
hev.wait()


Expand Down
8 changes: 8 additions & 0 deletions dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,10 @@ usm_ndarray_take(const dpctl::tensor::usm_ndarray &src,
ind_offsets.push_back(py::ssize_t(0));
}

if (ind_nelems == 0) {
return std::make_pair(sycl::event{}, sycl::event{});
}

char **packed_ind_ptrs = sycl::malloc_device<char *>(k, exec_q);

if (packed_ind_ptrs == nullptr) {
Expand Down Expand Up @@ -717,6 +721,10 @@ usm_ndarray_put(const dpctl::tensor::usm_ndarray &dst,
ind_offsets.push_back(py::ssize_t(0));
}

if (ind_nelems == 0) {
return std::make_pair(sycl::event{}, sycl::event{});
}

char **packed_ind_ptrs = sycl::malloc_device<char *>(k, exec_q);

if (packed_ind_ptrs == nullptr) {
Expand Down
Loading

0 comments on commit 7757857

Please sign in to comment.