diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 9c53bd8b08..b5a66b88f6 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -112,6 +112,7 @@ set(_reduction_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp ) set(_sorting_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index ea60fa0eb7..df10879b68 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -198,6 +198,7 @@ ) from ._searchsorted import searchsorted from ._set_functions import ( + isin, unique_all, unique_counts, unique_inverse, @@ -394,4 +395,5 @@ "top_k", "dldevice_to_sycl_device", "sycl_device_to_dldevice", + "isin", ] diff --git a/dpctl/tensor/_clip.py b/dpctl/tensor/_clip.py index e00036a94e..250f116927 100644 --- a/dpctl/tensor/_clip.py +++ b/dpctl/tensor/_clip.py @@ -23,16 +23,16 @@ _empty_like_pair_orderK, _empty_like_triple_orderK, ) -from dpctl.tensor._elementwise_common import ( +from dpctl.tensor._manipulation_functions import _broadcast_shape_impl +from dpctl.tensor._type_utils import _can_cast +from dpctl.utils import ExecutionPlacementError, SequentialOrderManager + +from ._scalar_utils import ( _get_dtype, _get_queue_usm_type, _get_shape, _validate_dtype, ) -from dpctl.tensor._manipulation_functions import _broadcast_shape_impl -from dpctl.tensor._type_utils import _can_cast -from dpctl.utils import ExecutionPlacementError, SequentialOrderManager - from ._type_utils import ( _resolve_one_strong_one_weak_types, _resolve_one_strong_two_weak_types, diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index ddfae7155a..491ef75c56 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -14,24 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numbers - -import numpy as np - import dpctl -import dpctl.memory as dpm import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti from dpctl.tensor._manipulation_functions import _broadcast_shape_impl -from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer from dpctl.utils import ExecutionPlacementError, SequentialOrderManager from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK +from ._scalar_utils import ( + _get_dtype, + _get_queue_usm_type, + _get_shape, + _validate_dtype, +) from ._type_utils import ( - WeakBooleanType, - WeakComplexType, - WeakFloatingType, - WeakIntegralType, _acceptance_fn_default_binary, _acceptance_fn_default_unary, _all_data_types, @@ -39,7 +35,6 @@ _find_buf_dtype2, _find_buf_dtype_in_place_op, _resolve_weak_types, - _to_device_supported_dtype, ) @@ -289,78 +284,6 @@ def __call__(self, x, /, *, out=None, order="K"): return out -def _get_queue_usm_type(o): - """Return SYCL device where object `o` allocated memory, or None.""" - if isinstance(o, dpt.usm_ndarray): - return o.sycl_queue, o.usm_type - elif hasattr(o, "__sycl_usm_array_interface__"): - try: - m = dpm.as_usm_memory(o) - return m.sycl_queue, m.get_usm_type() - except Exception: - return None, None - return None, None - - -def _get_dtype(o, dev): - if isinstance(o, dpt.usm_ndarray): - return o.dtype - if hasattr(o, "__sycl_usm_array_interface__"): - return dpt.asarray(o).dtype - if _is_buffer(o): - host_dt = np.array(o).dtype - dev_dt = _to_device_supported_dtype(host_dt, dev) - return dev_dt - if hasattr(o, "dtype"): - dev_dt = _to_device_supported_dtype(o.dtype, dev) - return dev_dt - if isinstance(o, bool): - return WeakBooleanType(o) - if isinstance(o, int): - return WeakIntegralType(o) - if isinstance(o, float): - return WeakFloatingType(o) - if isinstance(o, complex): - return WeakComplexType(o) - return np.object_ - - -def _validate_dtype(dt) -> bool: - return isinstance( - dt, - (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), - ) or ( - isinstance(dt, dpt.dtype) - and dt - in [ - dpt.bool, - dpt.int8, - dpt.uint8, - dpt.int16, - dpt.uint16, - dpt.int32, - dpt.uint32, - dpt.int64, - dpt.uint64, - dpt.float16, - dpt.float32, - dpt.float64, - dpt.complex64, - dpt.complex128, - ] - ) - - -def _get_shape(o): - if isinstance(o, dpt.usm_ndarray): - return o.shape - if _is_buffer(o): - return memoryview(o).shape - if isinstance(o, numbers.Number): - return tuple() - return getattr(o, "shape", tuple()) - - class BinaryElementwiseFunc: """ Class that implements binary element-wise functions. diff --git a/dpctl/tensor/_scalar_utils.py b/dpctl/tensor/_scalar_utils.py new file mode 100644 index 0000000000..8b6aa01c86 --- /dev/null +++ b/dpctl/tensor/_scalar_utils.py @@ -0,0 +1,111 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers + +import numpy as np + +import dpctl.memory as dpm +import dpctl.tensor as dpt +from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer + +from ._type_utils import ( + WeakBooleanType, + WeakComplexType, + WeakFloatingType, + WeakIntegralType, + _to_device_supported_dtype, +) + + +def _get_queue_usm_type(o): + """Return SYCL device where object `o` allocated memory, or None.""" + if isinstance(o, dpt.usm_ndarray): + return o.sycl_queue, o.usm_type + elif hasattr(o, "__sycl_usm_array_interface__"): + try: + m = dpm.as_usm_memory(o) + return m.sycl_queue, m.get_usm_type() + except Exception: + return None, None + return None, None + + +def _get_dtype(o, dev): + if isinstance(o, dpt.usm_ndarray): + return o.dtype + if hasattr(o, "__sycl_usm_array_interface__"): + return dpt.asarray(o).dtype + if _is_buffer(o): + host_dt = np.array(o).dtype + dev_dt = _to_device_supported_dtype(host_dt, dev) + return dev_dt + if hasattr(o, "dtype"): + dev_dt = _to_device_supported_dtype(o.dtype, dev) + return dev_dt + if isinstance(o, bool): + return WeakBooleanType(o) + if isinstance(o, int): + return WeakIntegralType(o) + if isinstance(o, float): + return WeakFloatingType(o) + if isinstance(o, complex): + return WeakComplexType(o) + return np.object_ + + +def _validate_dtype(dt) -> bool: + return isinstance( + dt, + (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), + ) or ( + isinstance(dt, dpt.dtype) + and dt + in [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + dpt.float16, + dpt.float32, + dpt.float64, + dpt.complex64, + dpt.complex128, + ] + ) + + +def _get_shape(o): + if isinstance(o, dpt.usm_ndarray): + return o.shape + if _is_buffer(o): + return memoryview(o).shape + if isinstance(o, numbers.Number): + return tuple() + return getattr(o, "shape", tuple()) + + +__all__ = [ + "_get_dtype", + "_get_queue_usm_type", + "_get_shape", + "_validate_dtype", +] diff --git a/dpctl/tensor/_search_functions.py b/dpctl/tensor/_search_functions.py index a09647ae2e..e09535bd3a 100644 --- a/dpctl/tensor/_search_functions.py +++ b/dpctl/tensor/_search_functions.py @@ -17,16 +17,16 @@ import dpctl import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti -from dpctl.tensor._elementwise_common import ( +from dpctl.tensor._manipulation_functions import _broadcast_shape_impl +from dpctl.utils import ExecutionPlacementError, SequentialOrderManager + +from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK +from ._scalar_utils import ( _get_dtype, _get_queue_usm_type, _get_shape, _validate_dtype, ) -from dpctl.tensor._manipulation_functions import _broadcast_shape_impl -from dpctl.utils import ExecutionPlacementError, SequentialOrderManager - -from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK from ._type_utils import ( WeakBooleanType, WeakComplexType, diff --git a/dpctl/tensor/_set_functions.py b/dpctl/tensor/_set_functions.py index d23bc68bd5..adbe371002 100644 --- a/dpctl/tensor/_set_functions.py +++ b/dpctl/tensor/_set_functions.py @@ -16,9 +16,17 @@ from typing import NamedTuple +import dpctl import dpctl.tensor as dpt import dpctl.utils as du +from ._copy_utils import _empty_like_orderK +from ._scalar_utils import ( + _get_dtype, + _get_queue_usm_type, + _get_shape, + _validate_dtype, +) from ._tensor_elementwise_impl import _not_equal, _subtract from ._tensor_impl import ( _copy_usm_ndarray_into_usm_ndarray, @@ -31,11 +39,17 @@ ) from ._tensor_sorting_impl import ( _argsort_ascending, + _isin, _searchsorted_left, _sort_ascending, ) +from ._type_utils import ( + _resolve_weak_types_all_py_ints, + _to_device_supported_dtype, +) __all__ = [ + "isin", "unique_values", "unique_counts", "unique_inverse", @@ -624,3 +638,130 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult: inv, _counts, ) + + +def isin(x, test_elements, /, *, invert=False): + """ + Tests `x in test_elements` for each element of `x`. Returns a boolean array + with the same shape as `x` that is `True` where the element is in + `test_elements`, `False` otherwise. + + Args: + x (Union[usm_ndarray, bool, int, float, complex]): + input element or elements. + test_elements (Union[usm_ndarray, bool, int, float, complex]): + elements against which to test each value of `x`. + invert (Optional[bool]): + if `True`, the output results are inverted, i.e., are equivalent to + testing `x not in test_elements` for each element of `x`. + Default: `False`. + + Returns: + usm_ndarray: + an array of the inclusion test results. The returned array has a + boolean data type and the same shape as `x`. + """ + q1, x_usm_type = _get_queue_usm_type(x) + q2, test_usm_type = _get_queue_usm_type(test_elements) + if q1 is None and q2 is None: + raise du.ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments. " + "One of the arguments must represent USM allocation and " + "expose `__sycl_usm_array_interface__` property" + ) + if q1 is None: + exec_q = q2 + res_usm_type = test_usm_type + elif q2 is None: + exec_q = q1 + res_usm_type = x_usm_type + else: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise du.ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x_usm_type, + test_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + sycl_dev = exec_q.sycl_device + + x_dt = _get_dtype(x, sycl_dev) + test_dt = _get_dtype(test_elements, sycl_dev) + if not all(_validate_dtype(dt) for dt in (x_dt, test_dt)): + raise ValueError("Operands have unsupported data types") + + x_sh = _get_shape(x) + if isinstance(test_elements, dpt.usm_ndarray) and test_elements.size == 0: + if invert: + return dpt.ones( + x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q + ) + else: + return dpt.zeros( + x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q + ) + + dt1, dt2 = _resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev) + dt = _to_device_supported_dtype(dpt.result_type(dt1, dt2), sycl_dev) + + if not isinstance(x, dpt.usm_ndarray): + x_arr = dpt.asarray( + x, dtype=dt1, usm_type=res_usm_type, sycl_queue=exec_q + ) + else: + x_arr = x + + if not isinstance(test_elements, dpt.usm_ndarray): + test_arr = dpt.asarray( + test_elements, dtype=dt2, usm_type=res_usm_type, sycl_queue=exec_q + ) + else: + test_arr = test_elements + + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + + if x_dt != dt: + x_buf = _empty_like_orderK(x_arr, dt, res_usm_type, sycl_dev) + ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray( + src=x_arr, dst=x_buf, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, ev) + else: + x_buf = x_arr + + if test_dt != dt: + # copy into C-contiguous memory, because the array will be flattened + test_buf = dpt.empty_like(test_arr, dtype=dt, order="C") + ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray( + src=test_arr, dst=test_buf, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, ev) + else: + test_buf = test_arr + + test_buf = dpt.reshape(test_buf, -1) + test_buf = dpt.sort(test_buf) + + dst = dpt.empty_like( + x_buf, dtype=dpt.bool, usm_type=res_usm_type, order="C" + ) + + dep_evs = _manager.submitted_events + ht_ev, s_ev = _isin( + needles=x_buf, + hay=test_buf, + dst=dst, + sycl_queue=exec_q, + invert=invert, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, s_ev) + return dst diff --git a/dpctl/tensor/_utility_functions.py b/dpctl/tensor/_utility_functions.py index 3ac6c40546..3ccf283dbf 100644 --- a/dpctl/tensor/_utility_functions.py +++ b/dpctl/tensor/_utility_functions.py @@ -21,14 +21,14 @@ import dpctl.tensor._tensor_impl as ti import dpctl.tensor._tensor_reductions_impl as tri import dpctl.utils as du -from dpctl.tensor._elementwise_common import ( + +from ._numpy_helper import normalize_axis_index, normalize_axis_tuple +from ._scalar_utils import ( _get_dtype, _get_queue_usm_type, _get_shape, _validate_dtype, ) - -from ._numpy_helper import normalize_axis_index, normalize_axis_tuple from ._type_utils import ( _resolve_one_strong_one_weak_types, _resolve_one_strong_two_weak_types, diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/isin.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/isin.hpp new file mode 100644 index 0000000000..3e88188253 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/sorting/isin.hpp @@ -0,0 +1,240 @@ +//=== isin.hpp - ---*-C++-*--/===// +// Implementation of searching for membership in sorted array +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor membership operations. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/search_sorted_detail.hpp" +#include "utils/offset_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +using dpctl::tensor::ssize_t; + +template +struct IsinFunctor +{ +private: + bool invert; + const T *hay_tp; + const T *needles_tp; + bool *out_tp; + std::size_t hay_nelems; + HayIndexerT hay_indexer; + NeedlesIndexerT needles_indexer; + OutIndexerT out_indexer; + +public: + IsinFunctor(const bool invert_, + const T *hay_, + const T *needles_, + bool *out_, + const std::size_t hay_nelems_, + const HayIndexerT &hay_indexer_, + const NeedlesIndexerT &needles_indexer_, + const OutIndexerT &out_indexer_) + : invert(invert_), hay_tp(hay_), needles_tp(needles_), out_tp(out_), + hay_nelems(hay_nelems_), hay_indexer(hay_indexer_), + needles_indexer(needles_indexer_), out_indexer(out_indexer_) + { + } + + void operator()(sycl::id<1> id) const + { + static constexpr Compare comp{}; + + const std::size_t i = id[0]; + const T needle_v = needles_tp[needles_indexer(i)]; + + // position of the needle_v in the hay array + std::size_t pos{}; + + static constexpr std::size_t zero(0); + // search in hay in left-closed interval, give `pos` such that + // hay[pos - 1] < needle_v <= hay[pos] + + // lower_bound returns the first pos such that bool(hay[pos] < + // needle_v) is false, i.e. needle_v <= hay[pos] + pos = static_cast( + search_sorted_detail::lower_bound_indexed_impl( + hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer)); + bool out = (pos == hay_nelems ? false : hay_tp[pos] == needle_v); + out_tp[out_indexer(i)] = (invert) ? !out : out; + } +}; + +typedef sycl::event (*isin_contig_impl_fp_ptr_t)( + sycl::queue &, + const bool, + const std::size_t, + const std::size_t, + const char *, + const ssize_t, + const char *, + const ssize_t, + char *, + const ssize_t, + const std::vector &); + +template class isin_contig_impl_krn; + +template +sycl::event isin_contig_impl(sycl::queue &exec_q, + const bool invert, + const std::size_t hay_nelems, + const std::size_t needles_nelems, + const char *hay_cp, + const ssize_t hay_offset, + const char *needles_cp, + const ssize_t needles_offset, + char *out_cp, + const ssize_t out_offset, + const std::vector &depends) +{ + const T *hay_tp = reinterpret_cast(hay_cp) + hay_offset; + const T *needles_tp = + reinterpret_cast(needles_cp) + needles_offset; + + bool *out_tp = reinterpret_cast(out_cp) + out_offset; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using KernelName = class isin_contig_impl_krn; + + sycl::range<1> gRange(needles_nelems); + + using TrivialIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + static constexpr TrivialIndexerT hay_indexer{}; + static constexpr TrivialIndexerT needles_indexer{}; + static constexpr TrivialIndexerT out_indexer{}; + + const auto fnctr = + IsinFunctor(invert, hay_tp, needles_tp, out_tp, hay_nelems, + hay_indexer, needles_indexer, out_indexer); + + cgh.parallel_for(gRange, fnctr); + }); + + return comp_ev; +} + +typedef sycl::event (*isin_strided_impl_fp_ptr_t)( + sycl::queue &, + const bool, + const std::size_t, + const std::size_t, + const char *, + const ssize_t, + const ssize_t, + const char *, + const ssize_t, + char *, + const ssize_t, + int, + const ssize_t *, + const std::vector &); + +template class isin_strided_impl_krn; + +template +sycl::event isin_strided_impl( + sycl::queue &exec_q, + const bool invert, + const std::size_t hay_nelems, + const std::size_t needles_nelems, + const char *hay_cp, + const ssize_t hay_offset, + // hay is 1D, so hay_nelems, hay_offset, hay_stride describe strided array + const ssize_t hay_stride, + const char *needles_cp, + const ssize_t needles_offset, + char *out_cp, + const ssize_t out_offset, + const int needles_nd, + // packed_shape_strides is [needles_shape, needles_strides, + // out_strides] has length of 3*needles_nd + const ssize_t *packed_shape_strides, + const std::vector &depends) +{ + const T *hay_tp = reinterpret_cast(hay_cp); + const T *needles_tp = reinterpret_cast(needles_cp); + + bool *out_tp = reinterpret_cast(out_cp); + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + sycl::range<1> gRange(needles_nelems); + + using HayIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + const HayIndexerT hay_indexer( + /* offset */ hay_offset, + /* size */ hay_nelems, + /* step */ hay_stride); + + using NeedlesIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const ssize_t *needles_shape_strides = packed_shape_strides; + const NeedlesIndexerT needles_indexer(needles_nd, needles_offset, + needles_shape_strides); + using OutIndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const ssize_t *out_shape = packed_shape_strides; + const ssize_t *out_strides = packed_shape_strides + 2 * needles_nd; + const OutIndexerT out_indexer(needles_nd, out_offset, out_shape, + out_strides); + + const auto fnctr = + IsinFunctor( + invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer, + needles_indexer, out_indexer); + using KernelName = class isin_strided_impl_krn; + + cgh.parallel_for(gRange, fnctr); + }); + + return comp_ev; +} + +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sorting/isin.cpp b/dpctl/tensor/libtensor/source/sorting/isin.cpp new file mode 100644 index 0000000000..4eb825dae7 --- /dev/null +++ b/dpctl/tensor/libtensor/source/sorting/isin.cpp @@ -0,0 +1,322 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===--------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include "dpctl4pybind11.hpp" +#include +#include + +#include "kernels/sorting/isin.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +#include "rich_comparisons.hpp" +#include "simplify_iteration_space.hpp" + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace detail +{ + +using dpctl::tensor::kernels::isin_contig_impl_fp_ptr_t; + +static isin_contig_impl_fp_ptr_t + isin_contig_impl_dispatch_vector[td_ns::num_types]; + +template struct IsinContigFactory +{ + constexpr IsinContigFactory() {} + + fnT get() const + { + using dpctl::tensor::kernels::isin_contig_impl; + using Compare = typename AscendingSorter::type; + return isin_contig_impl; + } +}; + +using dpctl::tensor::kernels::isin_strided_impl_fp_ptr_t; + +static isin_strided_impl_fp_ptr_t + isin_strided_impl_dispatch_vector[td_ns::num_types]; + +template struct IsinStridedFactory +{ + constexpr IsinStridedFactory() {} + + fnT get() const + { + using dpctl::tensor::kernels::isin_strided_impl; + using Compare = typename AscendingSorter::type; + return isin_strided_impl; + } +}; + +void init_isin_dispatch_vector(void) +{ + + // Contiguous input function dispatch + td_ns::DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(isin_contig_impl_dispatch_vector); + + // Strided input function dispatch + td_ns::DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(isin_strided_impl_dispatch_vector); +} + +} // namespace detail + +/*! @brief search for needle from needles in sorted hay */ +std::pair +py_isin(const dpctl::tensor::usm_ndarray &needles, + const dpctl::tensor::usm_ndarray &hay, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const bool invert, + const std::vector &depends) +{ + const int hay_nd = hay.get_ndim(); + const int needles_nd = needles.get_ndim(); + const int dst_nd = dst.get_ndim(); + + if (hay_nd != 1 || needles_nd != dst_nd) { + throw py::value_error("Array dimensions mismatch"); + } + + // check that needle and dst have the same shape + std::size_t needles_nelems(1); + bool same_shape(true); + + const std::size_t hay_nelems = static_cast(hay.get_shape(0)); + + const py::ssize_t *needles_shape_ptr = needles.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = needles.get_shape_raw(); + + for (int i = 0; (i < needles_nd) && same_shape; ++i) { + const auto needles_sh_i = needles_shape_ptr[i]; + const auto dst_sh_i = dst_shape_ptr[i]; + + same_shape = same_shape && (needles_sh_i == dst_sh_i); + needles_nelems *= static_cast(needles_sh_i); + } + + if (!same_shape) { + throw py::value_error( + "Array of values to search for and array of their " + "dst do not have the same shape"); + } + + // check that dst is ample enough + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, + needles_nelems); + + // check that dst is writable + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + // check that queues are compatible + if (!dpctl::utils::queues_are_compatible(exec_q, {hay, needles, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + // if output array overlaps with input arrays, race condition results + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(dst, hay) || overlap(dst, needles)) { + throw py::value_error("Destination array overlaps with input."); + } + + const int hay_typenum = hay.get_typenum(); + const int needles_typenum = needles.get_typenum(); + const int dst_typenum = dst.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + const int hay_typeid = array_types.typenum_to_lookup_id(hay_typenum); + const int needles_typeid = + array_types.typenum_to_lookup_id(needles_typenum); + const int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + // check hay and needle have the same data-type + if (needles_typeid != hay_typeid) { + throw py::value_error( + "Hay array and needles array must have the same data types"); + } + // check that dst has boolean data type + const auto dst_typenum_t_v = static_cast(dst_typeid); + if (dst_typenum_t_v != td_ns::typenum_t::BOOL) { + throw py::value_error("dst array must have data-type bool"); + } + + if (needles_nelems == 0) { + // Nothing to do + return std::make_pair(sycl::event{}, sycl::event{}); + } + + // if all inputs are contiguous call contiguous implementations + // otherwise call strided implementation + const bool hay_is_c_contig = hay.is_c_contiguous(); + const bool hay_is_f_contig = hay.is_f_contiguous(); + + const bool needles_is_c_contig = needles.is_c_contiguous(); + const bool needles_is_f_contig = needles.is_f_contiguous(); + + const bool dst_is_c_contig = dst.is_c_contiguous(); + const bool dst_is_f_contig = dst.is_f_contiguous(); + + const bool all_c_contig = + (hay_is_c_contig && needles_is_c_contig && dst_is_c_contig); + const bool all_f_contig = + (hay_is_f_contig && needles_is_f_contig && dst_is_f_contig); + + const char *hay_data = hay.get_data(); + const char *needles_data = needles.get_data(); + + char *dst_data = dst.get_data(); + + if (all_c_contig || all_f_contig) { + auto fn = detail::isin_contig_impl_dispatch_vector[hay_typeid]; + + static constexpr py::ssize_t zero_offset(0); + + sycl::event comp_ev = fn(exec_q, invert, hay_nelems, needles_nelems, + hay_data, zero_offset, needles_data, + zero_offset, dst_data, zero_offset, depends); + + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {hay, needles, dst}, {comp_ev}), + comp_ev); + } + + // strided case + + const auto &needles_strides = needles.get_strides_vector(); + const auto &dst_strides = dst.get_strides_vector(); + + int simplified_nd = needles_nd; + + using shT = std::vector; + shT simplified_common_shape; + shT simplified_needles_strides; + shT simplified_dst_strides; + py::ssize_t needles_offset(0); + py::ssize_t dst_offset(0); + + if (simplified_nd == 0) { + // needles and dst have same nd + simplified_nd = 1; + simplified_common_shape.push_back(1); + simplified_needles_strides.push_back(0); + simplified_dst_strides.push_back(0); + } + else { + dpctl::tensor::py_internal::simplify_iteration_space( + // modified by refernce + simplified_nd, + // read-only inputs + needles_shape_ptr, needles_strides, dst_strides, + // output, modified by reference + simplified_common_shape, simplified_needles_strides, + simplified_dst_strides, needles_offset, dst_offset); + } + std::vector host_task_events; + host_task_events.reserve(2); + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + auto ptr_size_event_tuple = device_allocate_and_pack( + exec_q, host_task_events, + // vectors being packed + simplified_common_shape, simplified_needles_strides, + simplified_dst_strides); + auto packed_shape_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple)); + const sycl::event ©_shape_strides_ev = + std::get<2>(ptr_size_event_tuple); + const py::ssize_t *packed_shape_strides = packed_shape_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shape_strides_ev); + + auto strided_fn = detail::isin_strided_impl_dispatch_vector[hay_typeid]; + + if (!strided_fn) { + throw std::runtime_error( + "No implementation for data types of input arrays"); + } + + static constexpr py::ssize_t zero_offset(0); + py::ssize_t hay_step = hay.get_strides_vector()[0]; + + const sycl::event &comp_ev = strided_fn( + exec_q, invert, hay_nelems, needles_nelems, hay_data, zero_offset, + hay_step, needles_data, needles_offset, dst_data, dst_offset, + simplified_nd, packed_shape_strides, all_deps); + + // free packed temporaries + sycl::event temporaries_cleanup_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {comp_ev}, packed_shape_strides_owner); + + host_task_events.push_back(temporaries_cleanup_ev); + const sycl::event &ht_ev = dpctl::utils::keep_args_alive( + exec_q, {hay, needles, dst}, host_task_events); + + return std::make_pair(ht_ev, comp_ev); +} + +void init_isin_functions(py::module_ m) +{ + dpctl::tensor::py_internal::detail::init_isin_dispatch_vector(); + + using dpctl::tensor::py_internal::py_isin; + m.def("_isin", &py_isin, py::arg("needles"), py::arg("hay"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("invert"), + py::arg("depends") = py::list()); +} + +} // end of namespace py_internal +} // end of namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sorting/isin.hpp b/dpctl/tensor/libtensor/source/sorting/isin.hpp new file mode 100644 index 0000000000..c855cd3d4c --- /dev/null +++ b/dpctl/tensor/libtensor/source/sorting/isin.hpp @@ -0,0 +1,42 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===--------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_isin_functions(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/tensor_sorting.cpp b/dpctl/tensor/libtensor/source/tensor_sorting.cpp index 1a264c780b..3b9bed7768 100644 --- a/dpctl/tensor/libtensor/source/tensor_sorting.cpp +++ b/dpctl/tensor/libtensor/source/tensor_sorting.cpp @@ -25,6 +25,7 @@ #include +#include "sorting/isin.hpp" #include "sorting/merge_argsort.hpp" #include "sorting/merge_sort.hpp" #include "sorting/radix_argsort.hpp" @@ -36,6 +37,7 @@ namespace py = pybind11; PYBIND11_MODULE(_tensor_sorting_impl, m) { + dpctl::tensor::py_internal::init_isin_functions(m); dpctl::tensor::py_internal::init_merge_sort_functions(m); dpctl::tensor::py_internal::init_merge_argsort_functions(m); dpctl::tensor::py_internal::init_searchsorted_functions(m); diff --git a/dpctl/tests/test_tensor_isin.py b/dpctl/tests/test_tensor_isin.py new file mode 100644 index 0000000000..f758653903 --- /dev/null +++ b/dpctl/tests/test_tensor_isin.py @@ -0,0 +1,168 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported +from dpctl.utils import ExecutionPlacementError + + +@pytest.mark.parametrize( + "dtype", + [ + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", + ], +) +def test_isin_basic(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n = 100 + x = dpt.arange(n, dtype=dtype) + test = dpt.arange(n - 1, dtype=dtype) + r1 = dpt.isin(x, test) + assert dpt.all(r1[:-1]) + assert not r1[-1] + assert r1.shape == x.shape + + # test with invert keyword + r2 = dpt.isin(x, test, invert=True) + assert not dpt.all(r2[:-1]) + assert r2[-1] + assert r2.shape == x.shape + + +def test_isin_basic_bool(): + dt = dpt.bool + n = 100 + x = dpt.zeros(n, dtype=dt) + x[-1] = True + test = dpt.zeros((), dtype=dt) + r1 = dpt.isin(x, test) + assert dpt.all(r1[:-1]) + assert not r1[-1] + assert r1.shape == x.shape + + r2 = dpt.isin(x, test, invert=True) + assert not dpt.all(r2[:-1]) + assert r2[-1] + assert r2.shape == x.shape + + +@pytest.mark.parametrize( + "dtype", + [ + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", + ], +) +def test_isin_strided(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n, m = 100, 20 + x = dpt.zeros((n, m), dtype=dtype, order="F") + x[:, ::2] = dpt.arange(1, (m / 2) + 1, dtype=dtype) + test = dpt.arange(1, (m / 2) + 1, dtype=dtype) + r1 = dpt.isin(x, test) + assert dpt.all(r1[:, ::2]) + assert not dpt.all(r1[:, 1::2]) + assert r1.shape == x.shape + + # test with invert keyword + r2 = dpt.isin(x, test, invert=True) + assert not dpt.all(r2[:, ::2]) + assert dpt.all(r2[:, 1::2]) + assert r2.shape == x.shape + + +def test_isin_strided_bool(): + dt = dpt.bool + n, m = 100, 20 + x = dpt.ones((n, m), dtype=dt, order="F") + x[:, ::2] = False + test = dpt.zeros((), dtype=dt) + r1 = dpt.isin(x, test) + assert dpt.all(r1[:, ::2]) + assert not dpt.all(r1[:, 1::2]) + assert r1.shape == x.shape + + # test with invert keyword + r2 = dpt.isin(x, test, invert=True) + assert not dpt.all(r2[:, ::2]) + assert dpt.all(r2[:, 1::2]) + assert r2.shape == x.shape + + +def test_isin_empty_inputs(): + get_queue_or_skip() + + x = dpt.ones((10, 0, 1), dtype="i4") + test = dpt.ones((), dtype="i4") + res1 = dpt.isin(x, test) + assert isinstance(res1, dpt.usm_ndarray) + assert res1.size == 0 + assert res1.shape == x.shape + assert res1.dtype == dpt.bool + + res2 = dpt.isin(x, test, invert=True) + assert isinstance(res2, dpt.usm_ndarray) + assert res2.size == 0 + assert res2.shape == x.shape + assert res2.dtype == dpt.bool + + x = dpt.ones((3, 3), dtype="i4") + test = dpt.ones(0, dtype="i4") + res3 = dpt.isin(x, test) + assert isinstance(res3, dpt.usm_ndarray) + assert res3.shape == x.shape + assert res3.dtype == dpt.bool + assert not dpt.all(res3) + + res4 = dpt.isin(x, test, invert=True) + assert isinstance(res4, dpt.usm_ndarray) + assert res4.shape == x.shape + assert res4.dtype == dpt.bool + assert dpt.all(res4) + + +def test_isin_validation(): + with pytest.raises(ExecutionPlacementError): + dpt.isin(1, 1)