diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 188f3cd291..374e1bb58f 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -120,7 +120,7 @@ def take(x, indices, /, *, axis=None, mode="wrap"): return res -def put(x, indices, vals, /, *, axis=None, mode="wrap"): +def put(x, indices, vals, /, *, axis=None, mode="wrap", unique_only=False): """put(x, indices, vals, axis=None, mode="wrap") Puts values of an array into another array @@ -144,6 +144,10 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"): negative indices. "clip" - clips indices to (0 <= i < n) Default: `"wrap"`. + unique_only: + If True, only unique indices will be used. + This will prevent undefined behavior when indices repeat, + but may negatively impact performance. """ if not isinstance(x, dpt.usm_ndarray): raise TypeError( @@ -175,6 +179,22 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"): indices.dtype ) ) + if not isinstance(unique_only, bool): + raise TypeError("`unique_only` expected `bool`, got `{}`".format( + type(unique_only) + )) + + not_unique = False + if unique_only: + indices = dpt.flip(indices) + _, local_indices = np.unique(indices, return_index=True) + not_unique = len(indices) != len(local_indices) + + if not_unique: + local_indices = dpt.asarray(local_indices) + indices = dpt.take(indices, local_indices) + indices = dpt.flip(indices) + queues_.append(indices.sycl_queue) usm_types_.append(indices.usm_type) exec_q = dpctl.utils.get_execution_queue(queues_) @@ -207,7 +227,10 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"): vals = dpt.asarray( vals, dtype=x.dtype, usm_type=vals_usm_type, sycl_queue=exec_q ) - + if not_unique: + vals = dpt.flip(vals) + vals = dpt.take(vals, local_indices) + vals = dpt.flip(vals) vals = dpt.broadcast_to(vals, val_shape) hev, _ = ti._put(x, (indices,), vals, axis, mode, sycl_queue=exec_q)