From 03fd73794070b80f7a41a94045704e1ef8721188 Mon Sep 17 00:00:00 2001
From: ndgrigorian <46709016+ndgrigorian@users.noreply.github.com>
Date: Wed, 25 Oct 2023 18:48:14 -0700
Subject: [PATCH] Implements ``dpctl.tensor.logsumexp`` and ``dpctl.tensor.reduce_hypot`` (#1446)

* Implements logsumexp and reduce_hypot

* Implements dedicated kernels for temp reductions over axes 1 and 0 in contiguous arrays

* logsumexp and reduce_hypot no longer use atomics
  This change was made to improve the accuracy of these functions

* Adds tests for reduce_hypot and logsumexp

* Arithmetic reductions no longer use atomics for inexact types
  This change is intended to improve the numerical stability of sum and prod

* Removed support for atomic reduction for max and min

* Adds new tests for reductions

* Split reductions into multiple source files

* Remove unnecessary imports of reduction init functions

* Added functions for querying reduction atomic support per type per function

* Corrected ``min`` contig variant typo
  These variants were using ``sycl::maximum`` rather than ``sycl::minimum``

* Removes _tree_reduction_over_axis
  Use lambdas to ignore atomic-specific arguments to hypot and logsumexp dtype_supported functions

* Always use atomic implementation for min/max if available
  For add/multiplies reductions, use tree reduction for FP types, real and complex, to get better round-off accumulation properties.

* ``logaddexp`` implementation moved to math_utils
  Reduces code repetition between logsumexp and logaddexp

---------

Co-authored-by: Oleksandr Pavlyk
---
 dpctl/tensor/CMakeLists.txt | 16 +-
 dpctl/tensor/__init__.py | 13 +-
 dpctl/tensor/_reduction.py | 159 +-
 .../elementwise_functions/logaddexp.hpp | 27 +-
 .../libtensor/include/kernels/reductions.hpp | 4556 ++++++++++++-----
 .../libtensor/include/utils/math_utils.hpp | 20 +
 .../libtensor/include/utils/sycl_utils.hpp | 40 +
 .../libtensor/source/reduction_over_axis.cpp | 514 --
 .../libtensor/source/reduction_over_axis.hpp | 689 ---
 .../libtensor/source/reductions/argmax.cpp | 119 +
 .../libtensor/source/reductions/argmax.hpp | 41 +
 .../libtensor/source/reductions/argmin.cpp | 119 +
 .../libtensor/source/reductions/argmin.hpp | 41 +
 .../libtensor/source/reductions/logsumexp.cpp | 136 +
 .../libtensor/source/reductions/logsumexp.hpp | 41 +
 .../libtensor/source/reductions/max.cpp | 171 +
 .../libtensor/source/reductions/max.hpp | 41 +
 .../libtensor/source/reductions/min.cpp | 173 +
 .../libtensor/source/reductions/min.hpp | 41 +
 .../libtensor/source/reductions/prod.cpp | 187 +
 .../libtensor/source/reductions/prod.hpp | 41 +
 .../source/reductions/reduce_hypot.cpp | 132 +
 .../source/reductions/reduce_hypot.hpp | 41 +
 .../reductions/reduction_atomic_support.hpp | 143 +
 .../source/reductions/reduction_common.cpp | 60 +
 .../source/reductions/reduction_common.hpp | 41 +
 .../source/reductions/reduction_over_axis.hpp | 1095 ++++
 .../libtensor/source/reductions/sum.cpp | 187 +
 .../libtensor/source/reductions/sum.hpp | 41 +
 dpctl/tensor/libtensor/source/tensor_py.cpp | 2 +-
 dpctl/tests/test_tensor_sum.py | 15 +
 dpctl/tests/test_usm_ndarray_reductions.py | 195 +
 32 files changed, 6735 insertions(+), 2402 deletions(-)
 delete mode 100644 dpctl/tensor/libtensor/source/reduction_over_axis.cpp
 delete mode 100644 dpctl/tensor/libtensor/source/reduction_over_axis.hpp
 create mode 100644 dpctl/tensor/libtensor/source/reductions/argmax.cpp
 create mode 100644 dpctl/tensor/libtensor/source/reductions/argmax.hpp
 create mode 100644 dpctl/tensor/libtensor/source/reductions/argmin.cpp
 create mode 
100644 dpctl/tensor/libtensor/source/reductions/argmin.hpp create mode 100644 dpctl/tensor/libtensor/source/reductions/logsumexp.cpp create mode 100644 dpctl/tensor/libtensor/source/reductions/logsumexp.hpp create mode 100644 dpctl/tensor/libtensor/source/reductions/max.cpp create mode 100644 dpctl/tensor/libtensor/source/reductions/max.hpp create mode 100644 dpctl/tensor/libtensor/source/reductions/min.cpp create mode 100644 dpctl/tensor/libtensor/source/reductions/min.hpp create mode 100644 dpctl/tensor/libtensor/source/reductions/prod.cpp create mode 100644 dpctl/tensor/libtensor/source/reductions/prod.hpp create mode 100644 dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp create mode 100644 dpctl/tensor/libtensor/source/reductions/reduce_hypot.hpp create mode 100644 dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp create mode 100644 dpctl/tensor/libtensor/source/reductions/reduction_common.cpp create mode 100644 dpctl/tensor/libtensor/source/reductions/reduction_common.hpp create mode 100644 dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp create mode 100644 dpctl/tensor/libtensor/source/reductions/sum.cpp create mode 100644 dpctl/tensor/libtensor/source/reductions/sum.hpp diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 5247b4953b..9c02a325bc 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -102,6 +102,17 @@ set(_elementwise_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/true_divide.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/trunc.cpp ) +set(_reduction_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduction_common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmax.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmin.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/logsumexp.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/max.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/min.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/prod.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduce_hypot.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp +) set(_tensor_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_py.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp @@ -120,11 +131,11 @@ set(_tensor_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp ) list(APPEND _tensor_impl_sources ${_elementwise_sources} + ${_reduction_sources} ) set(python_module_name _tensor_impl) @@ -138,12 +149,13 @@ endif() set(_no_fast_math_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp ) list(APPEND _no_fast_math_sources ${_elementwise_sources} + ${_reduction_sources} ) + foreach(_src_fn ${_no_fast_math_sources}) get_source_file_property(_cmpl_options_prop ${_src_fn} COMPILE_OPTIONS) set(_combined_options_prop ${_cmpl_options_prop} "${_clang_prefix}-fno-fast-math") diff --git a/dpctl/tensor/__init__.py 
b/dpctl/tensor/__init__.py index 209a6d4e56..5eee3e9ab9 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -165,7 +165,16 @@ tanh, trunc, ) -from ._reduction import argmax, argmin, max, min, prod, sum +from ._reduction import ( + argmax, + argmin, + logsumexp, + max, + min, + prod, + reduce_hypot, + sum, +) from ._testing import allclose __all__ = [ @@ -324,4 +333,6 @@ "copysign", "rsqrt", "clip", + "logsumexp", + "reduce_hypot", ] diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py index aac1c84677..0edc9ac12b 100644 --- a/dpctl/tensor/_reduction.py +++ b/dpctl/tensor/_reduction.py @@ -52,6 +52,28 @@ def _default_reduction_dtype(inp_dt, q): return res_dt +def _default_reduction_dtype_fp_types(inp_dt, q): + """Gives default output data type for given input data + type `inp_dt` when reduction is performed on queue `q` + and the reduction supports only floating-point data types + """ + inp_kind = inp_dt.kind + if inp_kind in "biu": + res_dt = dpt.dtype(ti.default_device_fp_type(q)) + can_cast_v = dpt.can_cast(inp_dt, res_dt) + if not can_cast_v: + _fp64 = q.sycl_device.has_aspect_fp64 + res_dt = dpt.float64 if _fp64 else dpt.float32 + elif inp_kind in "f": + res_dt = dpt.dtype(ti.default_device_fp_type(q)) + if res_dt.itemsize < inp_dt.itemsize: + res_dt = inp_dt + elif inp_kind in "c": + raise TypeError("reduction not defined for complex types") + + return res_dt + + def _reduction_over_axis( x, axis, @@ -91,12 +113,15 @@ def _reduction_over_axis( res_shape = res_shape + (1,) * red_nd inv_perm = sorted(range(nd), key=lambda d: perm[d]) res_shape = tuple(res_shape[i] for i in inv_perm) - return dpt.full( - res_shape, - _identity, - dtype=res_dt, - usm_type=res_usm_type, - sycl_queue=q, + return dpt.astype( + dpt.full( + res_shape, + _identity, + dtype=_default_reduction_type_fn(inp_dt, q), + usm_type=res_usm_type, + sycl_queue=q, + ), + res_dt, ) if red_nd == 0: return dpt.astype(x, res_dt, copy=False) @@ -116,7 +141,7 @@ def _reduction_over_axis( "Automatically determined reduction data type does not " "have direct implementation" ) - tmp_dt = _default_reduction_dtype(inp_dt, q) + tmp_dt = _default_reduction_type_fn(inp_dt, q) tmp = dpt.empty( res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q ) @@ -161,13 +186,13 @@ def sum(x, axis=None, dtype=None, keepdims=False): the returned array will have the default real-valued floating-point data type for the device where input array `x` is allocated. - * If x` has signed integral data type, the returned array + * If `x` has signed integral data type, the returned array will have the default signed integral type for the device where input array `x` is allocated. * If `x` has unsigned integral data type, the returned array will have the default unsigned integral type for the device where input array `x` is allocated. - * If `x` has a complex-valued floating-point data typee, + * If `x` has a complex-valued floating-point data type, the returned array will have the default complex-valued floating-pointer data type for the device where input array `x` is allocated. @@ -222,13 +247,13 @@ def prod(x, axis=None, dtype=None, keepdims=False): the returned array will have the default real-valued floating-point data type for the device where input array `x` is allocated. - * If x` has signed integral data type, the returned array + * If `x` has signed integral data type, the returned array will have the default signed integral type for the device where input array `x` is allocated. 
* If `x` has unsigned integral data type, the returned array will have the default unsigned integral type for the device where input array `x` is allocated. - * If `x` has a complex-valued floating-point data typee, + * If `x` has a complex-valued floating-point data type, the returned array will have the default complex-valued floating-pointer data type for the device where input array `x` is allocated. @@ -263,6 +288,118 @@ def prod(x, axis=None, dtype=None, keepdims=False): ) +def logsumexp(x, axis=None, dtype=None, keepdims=False): + """logsumexp(x, axis=None, dtype=None, keepdims=False) + + Calculates the logarithm of the sum of exponentials of elements in the + input array `x`. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int, Tuple[int, ...]]): + axis or axes along which values must be computed. If a tuple + of unique integers, values are computed over multiple axes. + If `None`, the result is computed over the entire array. + Default: `None`. + dtype (Optional[dtype]): + data type of the returned array. If `None`, the default data + type is inferred from the "kind" of the input array data type. + * If `x` has a real-valued floating-point data type, + the returned array will have the default real-valued + floating-point data type for the device where input + array `x` is allocated. + * If `x` has a boolean or integral data type, the returned array + will have the default floating point data type for the device + where input array `x` is allocated. + * If `x` has a complex-valued floating-point data type, + an error is raised. + If the data type (either specified or resolved) differs from the + data type of `x`, the input array elements are cast to the + specified data type before computing the result. Default: `None`. + keepdims (Optional[bool]): + if `True`, the reduced axes (dimensions) are included in the result + as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if `False`, the reduced axes are not included in + the returned array. Default: `False`. + Returns: + usm_ndarray: + an array containing the results. If the result was computed over + the entire array, a zero-dimensional array is returned. The returned + array has the data type as described in the `dtype` parameter + description above. + """ + return _reduction_over_axis( + x, + axis, + dtype, + keepdims, + ti._logsumexp_over_axis, + lambda inp_dt, res_dt, *_: ti._logsumexp_over_axis_dtype_supported( + inp_dt, res_dt + ), + _default_reduction_dtype_fp_types, + _identity=-dpt.inf, + ) + + +def reduce_hypot(x, axis=None, dtype=None, keepdims=False): + """reduce_hypot(x, axis=None, dtype=None, keepdims=False) + + Calculates the square root of the sum of squares of elements in the input + array `x`. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int, Tuple[int, ...]]): + axis or axes along which values must be computed. If a tuple + of unique integers, values are computed over multiple axes. + If `None`, the result is computed over the entire array. + Default: `None`. + dtype (Optional[dtype]): + data type of the returned array. If `None`, the default data + type is inferred from the "kind" of the input array data type. + * If `x` has a real-valued floating-point data type, + the returned array will have the default real-valued + floating-point data type for the device where input + array `x` is allocated. 
+ * If `x` has a boolean or integral data type, the returned array + will have the default floating point data type for the device + where input array `x` is allocated. + * If `x` has a complex-valued floating-point data type, + an error is raised. + If the data type (either specified or resolved) differs from the + data type of `x`, the input array elements are cast to the + specified data type before computing the result. Default: `None`. + keepdims (Optional[bool]): + if `True`, the reduced axes (dimensions) are included in the result + as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if `False`, the reduced axes are not included in + the returned array. Default: `False`. + Returns: + usm_ndarray: + an array containing the results. If the result was computed over + the entire array, a zero-dimensional array is returned. The returned + array has the data type as described in the `dtype` parameter + description above. + """ + return _reduction_over_axis( + x, + axis, + dtype, + keepdims, + ti._hypot_over_axis, + lambda inp_dt, res_dt, *_: ti._hypot_over_axis_dtype_supported( + inp_dt, res_dt + ), + _default_reduction_dtype_fp_types, + _identity=0, + ) + + def _comparison_over_axis(x, axis, keepdims, _reduction_fn): if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp index 90b7997a37..6a187da6f4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp @@ -31,6 +31,7 @@ #include #include +#include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" @@ -61,7 +62,8 @@ template struct LogAddExpFunctor resT operator()(const argT1 &in1, const argT2 &in2) const { - return impl(in1, in2); + using dpctl::tensor::math_utils::logaddexp; + return logaddexp(in1, in2); } template @@ -79,7 +81,8 @@ template struct LogAddExpFunctor impl_finite(-std::abs(diff[i])); } else { - res[i] = impl(in1[i], in2[i]); + using dpctl::tensor::math_utils::logaddexp; + res[i] = logaddexp(in1[i], in2[i]); } } @@ -87,26 +90,6 @@ template struct LogAddExpFunctor } private: - template T impl(T const &in1, T const &in2) const - { - if (in1 == in2) { // handle signed infinities - const T log2 = std::log(T(2)); - return in1 + log2; - } - else { - const T tmp = in1 - in2; - if (tmp > 0) { - return in1 + std::log1p(std::exp(-tmp)); - } - else if (tmp <= 0) { - return in2 + std::log1p(std::exp(tmp)); - } - else { - return std::numeric_limits::quiet_NaN(); - } - } - } - template T impl_finite(T const &in) const { return (in > 0) ? 
(in + std::log1p(std::exp(-in))) diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index 7cb97cd4f9..b9e2918c8c 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -685,7 +685,6 @@ sycl::event reduction_axis0_over_group_with_atomics_contig_impl( resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; constexpr resTy identity_val = su_ns::Identity::value; - ; const sycl::device &d = exec_q.get_device(); const auto &sg_sizes = d.get_info(); @@ -944,8 +943,103 @@ struct CustomReductionOverGroupNoAtomicFunctor } }; +typedef sycl::event (*reduction_strided_impl_fn_ptr)( + sycl::queue &, + size_t, + size_t, + const char *, + char *, + int, + const py::ssize_t *, + py::ssize_t, + py::ssize_t, + int, + const py::ssize_t *, + py::ssize_t, + const std::vector &); + +template +class reduction_over_group_temps_strided_krn; + +template +class custom_reduction_over_group_temps_strided_krn; + +template +class single_reduction_axis0_temps_contig_krn; + +template +class first_reduction_axis0_temps_contig_krn; + +template +class middle_reduction_axis0_temps_contig_krn; + +template +class final_reduction_axis0_temps_contig_krn; + +template +class single_custom_reduction_axis0_temps_contig_krn; + +template +class first_custom_reduction_axis0_temps_contig_krn; + +template +class middle_custom_reduction_axis0_temps_contig_krn; + +template +class final_custom_reduction_axis0_temps_contig_krn; + +template +class single_reduction_axis1_temps_contig_krn; + +template +class first_reduction_axis1_temps_contig_krn; + template -class reduction_over_group_temps_krn; +class middle_reduction_axis1_temps_contig_krn; + +template +class final_reduction_axis1_temps_contig_krn; + +template +class single_custom_reduction_axis1_temps_contig_krn; + +template +class first_custom_reduction_axis1_temps_contig_krn; template -class custom_reduction_over_group_temps_krn; +class middle_custom_reduction_axis1_temps_contig_krn; + +template +class final_custom_reduction_axis1_temps_contig_krn; template sycl::event reduction_over_group_temps_strided_impl( @@ -1020,7 +1122,7 @@ sycl::event reduction_over_group_temps_strided_impl( if constexpr (can_use_reduce_over_group::value) { - using KernelName = class reduction_over_group_temps_krn< + using KernelName = class reduction_over_group_temps_strided_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; @@ -1036,9 +1138,10 @@ sycl::event reduction_over_group_temps_strided_impl( else { using SlmT = sycl::local_accessor; SlmT local_memory = SlmT(localRange, cgh); - using KernelName = class custom_reduction_over_group_temps_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT, SlmT>; + using KernelName = + class custom_reduction_over_group_temps_strided_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; cgh.parallel_for( sycl::nd_range<1>(globalRange, localRange), @@ -1107,7 +1210,7 @@ sycl::event reduction_over_group_temps_strided_impl( if constexpr (can_use_reduce_over_group::value) { - using KernelName = class reduction_over_group_temps_krn< + using KernelName = class reduction_over_group_temps_strided_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; cgh.parallel_for( @@ -1123,9 +1226,10 @@ sycl::event reduction_over_group_temps_strided_impl( else { using SlmT = sycl::local_accessor; SlmT local_memory = SlmT(localRange, cgh); - 
using KernelName = class custom_reduction_over_group_temps_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT, SlmT>; + using KernelName = + class custom_reduction_over_group_temps_strided_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; cgh.parallel_for( sycl::nd_range<1>(globalRange, localRange), CustomReductionOverGroupNoAtomicFunctor< @@ -1180,9 +1284,10 @@ sycl::event reduction_over_group_temps_strided_impl( auto localRange = sycl::range<1>{wg}; if constexpr (can_use_reduce_over_group::value) { - using KernelName = class reduction_over_group_temps_krn< - resTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; + using KernelName = + class reduction_over_group_temps_strided_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; cgh.parallel_for( sycl::nd_range<1>(globalRange, localRange), ReductionOverGroupNoAtomicFunctor< @@ -1197,7 +1302,7 @@ sycl::event reduction_over_group_temps_strided_impl( using SlmT = sycl::local_accessor; SlmT local_memory = SlmT(localRange, cgh); using KernelName = - class custom_reduction_over_group_temps_krn< + class custom_reduction_over_group_temps_strided_krn< resTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT, SlmT>; cgh.parallel_for( @@ -1256,7 +1361,7 @@ sycl::event reduction_over_group_temps_strided_impl( if constexpr (can_use_reduce_over_group::value) { - using KernelName = class reduction_over_group_temps_krn< + using KernelName = class reduction_over_group_temps_strided_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; cgh.parallel_for( @@ -1272,9 +1377,10 @@ sycl::event reduction_over_group_temps_strided_impl( else { using SlmT = sycl::local_accessor; SlmT local_memory = SlmT(localRange, cgh); - using KernelName = class custom_reduction_over_group_temps_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT, SlmT>; + using KernelName = + class custom_reduction_over_group_temps_strided_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; cgh.parallel_for( sycl::nd_range<1>(globalRange, localRange), CustomReductionOverGroupNoAtomicFunctor< @@ -1304,1220 +1410,3219 @@ sycl::event reduction_over_group_temps_strided_impl( } } -/* @brief Types supported by comparison-reduction code based on atomic_ref */ -template -struct TypePairSupportDataForCompReductionAtomic +template +sycl::event reduction_axis1_over_group_temps_contig_impl( + sycl::queue &exec_q, + size_t iter_nelems, // number of reductions (num. of rows in a matrix + // when reducing over rows) + size_t reduction_nelems, // size of each reduction (length of rows, i.e. 
+ // number of columns) + const char *arg_cp, + char *res_cp, + py::ssize_t iter_arg_offset, + py::ssize_t iter_res_offset, + py::ssize_t reduction_arg_offset, + const std::vector &depends) { + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); - /* value if true a kernel for must be instantiated, false - * otherwise */ - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ - // input int32 - td_ns::TypePairDefinedEntry, - // input uint32 - td_ns::TypePairDefinedEntry, - // input int64 - td_ns::TypePairDefinedEntry, - // input uint64 - td_ns::TypePairDefinedEntry, - // input float - td_ns::TypePairDefinedEntry, - // input double - td_ns::TypePairDefinedEntry, - // fall-through - td_ns::NotDefinedEntry>::is_defined; -}; - -template -struct TypePairSupportDataForCompReductionTemps -{ + constexpr resTy identity_val = su_ns::Identity::value; - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - // input int8_t - td_ns::TypePairDefinedEntry, + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - // input uint8_t - td_ns::TypePairDefinedEntry, + constexpr size_t preferrered_reductions_per_wi = 8; + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, d.get_info()); - // input int16_t - td_ns::TypePairDefinedEntry, + size_t reductions_per_wi(preferrered_reductions_per_wi); + if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { + // reduction only requries 1 work-group, can output directly to res + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - // input uint16_t - td_ns::TypePairDefinedEntry, + using InputIterIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; - // input int32_t - td_ns::TypePairDefinedEntry, - // input uint32_t - td_ns::TypePairDefinedEntry, + InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{}; - // input int64_t - td_ns::TypePairDefinedEntry, + wg = max_wg; + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); - // input uint32_t - td_ns::TypePairDefinedEntry, + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); - // input half - td_ns::TypePairDefinedEntry, + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; - // input float - td_ns::TypePairDefinedEntry, + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = + class single_reduction_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; - // input double - td_ns::TypePairDefinedEntry, + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>(arg_tp, res_tp, 
ReductionOpT(), + identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class single_custom_reduction_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; - // input std::complex - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + reduction_nelems, iter_nelems, reductions_per_wi)); + } + }); + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups > 1); - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, + size_t second_iter_reduction_groups_ = + (reduction_groups + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); - // fall-through - td_ns::NotDefinedEntry>::is_defined; -}; + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (reduction_groups + second_iter_reduction_groups_), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; -template -struct MaxOverAxisAtomicStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Maximum; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::maximum; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl< - srcTy, dstTy, ReductionOpT>; - } + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unabled to allocate device_memory"); } else { - return nullptr; + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; } - } -}; -template -struct MaxOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - using ReductionOpT = sycl::maximum; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - else { - using ReductionOpT = su_ns::Maximum; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - } - else { - return nullptr; - } - } -}; + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); -template -struct MaxOverAxis1AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Maximum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::maximum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - } - else { - return nullptr; - } - } -}; + using NoOpIndexerT = 
dpctl::tensor::offset_utils::NoOpIndexer; + using RowsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + RowsIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; -template -struct MaxOverAxis0AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Maximum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::maximum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - } - else { - return nullptr; - } - } -}; + RowsIndexerT rows_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_nelems)}; + NoOpIndexerT noop_tmp_indexer{}; + InputOutputIterIndexerT in_out_iter_indexer{rows_indexer, + noop_tmp_indexer}; + ReductionIndexerT reduction_indexer{}; -template -struct MinOverAxisAtomicStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Minimum; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::minimum; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl< - srcTy, dstTy, ReductionOpT>; - } - } - else { - return nullptr; - } - } -}; + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; -template -struct MinOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionTemps< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_integral_v && - !std::is_same_v) { - using ReductionOpT = sycl::minimum; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class first_reduction_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + arg_tp, partially_reduced_tmp, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); } else { - using ReductionOpT = su_ns::Minimum; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class first_custom_reduction_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg_tp, partially_reduced_tmp, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + local_memory, reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); } - } - else { - return nullptr; - } - } -}; + }); -template -struct 
MinOverAxis1AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Minimum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::minimum; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - } - else { - return nullptr; - } - } -}; + size_t remaining_reduction_nelems = reduction_groups; -template -struct MinOverAxis0AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForCompReductionAtomic< - srcTy, dstTy>::is_defined) - { - if constexpr (std::is_floating_point::value) { - using ReductionOpT = su_ns::Minimum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - using ReductionOpT = sycl::minimum; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - } - else { - return nullptr; - } - } -}; - -// Sum - -/* @brief Types supported by plus-reduction code based on atomic_ref */ -template -struct TypePairSupportDataForSumReductionAtomic -{ - - /* value if true a kernel for must be instantiated, false - * otherwise */ - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int8 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint8 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int16 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint16 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int32 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint32 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int64 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint64 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input half - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input float - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input double - td_ns::TypePairDefinedEntry, - // fall-through - td_ns::NotDefinedEntry>::is_defined; -}; - -template -struct TypePairSupportDataForSumReductionTemps -{ - - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - 
td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; - // input int8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + while (remaining_reduction_nelems > + preferrered_reductions_per_wi * max_wg) { + size_t reduction_groups_ = + (remaining_reduction_nelems + + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups_ > 1); - // input uint8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + // keep reducing + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); - // input int16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; - // input uint16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; - // input int32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; - // input uint32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class middle_reduction_axis1_temps_contig_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class middle_custom_reduction_axis1_temps_contig_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + local_memory, remaining_reduction_nelems, + iter_nelems, preferrered_reductions_per_wi)); + } + }); - // input int64_t - td_ns::TypePairDefinedEntry, + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } - // input uint32_t - 
td_ns::TypePairDefinedEntry, + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); - // input half - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns:: - TypePairDefinedEntry>, - td_ns::TypePairDefinedEntry>, + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - // input float - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, - td_ns::TypePairDefinedEntry>, + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{}; - // input double - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; - // input std::complex - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); - // fall-throug - td_ns::NotDefinedEntry>::is_defined; -}; + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; -template -struct SumOverAxisAtomicStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSumReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::plus; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl; - } - else { - return nullptr; - } - } -}; + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class final_reduction_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>(temp_arg, res_tp, ReductionOpT(), + identity_val, in_out_iter_indexer, + reduction_indexer, + remaining_reduction_nelems, + iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class final_custom_reduction_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + } + }); -template -struct SumOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSumReductionTemps< - srcTy, dstTy>::is_defined) { - using ReductionOpT = sycl::plus; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - else { - 
return nullptr; - } - } -}; + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + const sycl::context &ctx = exec_q.get_context(); -template -struct SumOverAxis1AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForSumReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::plus; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - return nullptr; - } + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; } -}; +} -template -struct SumOverAxis0AtomicContigFactory +template +sycl::event reduction_axis0_over_group_temps_contig_impl( + sycl::queue &exec_q, + size_t iter_nelems, // number of reductions (num. of rows in a matrix + // when reducing over rows) + size_t reduction_nelems, // size of each reduction (length of rows, i.e. + // number of columns) + const char *arg_cp, + char *res_cp, + py::ssize_t iter_arg_offset, + py::ssize_t iter_res_offset, + py::ssize_t reduction_arg_offset, + const std::vector &depends) { - fnT get() const - { - if constexpr (TypePairSupportDataForSumReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::plus; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - return nullptr; - } - } -}; + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); -// Product + constexpr resTy identity_val = su_ns::Identity::value; -/* @brief Types supported by plus-reduction code based on atomic_ref */ -template -struct TypePairSupportDataForProductReductionAtomic -{ + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - /* value if true a kernel for must be instantiated, false - * otherwise */ - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int8 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint8 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int16 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint16 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int32 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint32 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input int64 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input 
uint64 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input half - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input float - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input double - td_ns::TypePairDefinedEntry, - // fall-through - td_ns::NotDefinedEntry>::is_defined; -}; + constexpr size_t preferrered_reductions_per_wi = 8; + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, d.get_info()); -template -struct TypePairSupportDataForProductReductionTemps -{ + size_t reductions_per_wi(preferrered_reductions_per_wi); + if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { + // reduction only requries 1 work-group, can output directly to res + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - static constexpr bool is_defined = std::disjunction< // disjunction is C++17 - // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; - // input int8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + NoOpIndexerT columns_indexer{}; + NoOpIndexerT result_indexer{}; + InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; - // input uint8_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + wg = max_wg; + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); - // input int16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); - // input uint16_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; - // input int32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = + class single_reduction_axis0_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; - // input uint32_t - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>(arg_tp, res_tp, ReductionOpT(), + identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + else { 
+ using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class single_custom_reduction_axis0_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; - // input int64_t - td_ns::TypePairDefinedEntry, + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + reduction_nelems, iter_nelems, reductions_per_wi)); + } + }); + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups > 1); - // input uint32_t - td_ns::TypePairDefinedEntry, + size_t second_iter_reduction_groups_ = + (reduction_groups + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); - // input half - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns:: - TypePairDefinedEntry>, - td_ns::TypePairDefinedEntry>, + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (reduction_groups + second_iter_reduction_groups_), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; - // input float - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, - td_ns::TypePairDefinedEntry>, + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unabled to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + } - // input double - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry>, + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); - // input std::complex - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; - td_ns::TypePairDefinedEntry, - outTy, - std::complex>, + NoOpIndexerT columns_indexer{}; + NoOpIndexerT noop_tmp_indexer{}; + InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + noop_tmp_indexer}; + ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; - // fall-throug - td_ns::NotDefinedEntry>::is_defined; -}; + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; -template -struct ProductOverAxisAtomicStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForProductReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::multiplies; - return dpctl::tensor::kernels:: - reduction_over_group_with_atomics_strided_impl; - } - else { - return nullptr; - } - } -}; + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class first_reduction_axis0_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, 
localRange), + ReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + arg_tp, partially_reduced_tmp, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class first_custom_reduction_axis0_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg_tp, partially_reduced_tmp, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + local_memory, reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + }); -template -struct ProductOverAxisTempsStridedFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForProductReductionTemps< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::multiplies; - return dpctl::tensor::kernels:: - reduction_over_group_temps_strided_impl; - } - else { - return nullptr; - } - } -}; + size_t remaining_reduction_nelems = reduction_groups; -template -struct ProductOverAxis1AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForProductReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::multiplies; - return dpctl::tensor::kernels:: - reduction_axis1_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - return nullptr; - } - } -}; + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; -template -struct ProductOverAxis0AtomicContigFactory -{ - fnT get() const - { - if constexpr (TypePairSupportDataForProductReductionAtomic< - srcTy, dstTy>::is_defined) - { - using ReductionOpT = sycl::multiplies; - return dpctl::tensor::kernels:: - reduction_axis0_over_group_with_atomics_contig_impl< - srcTy, dstTy, ReductionOpT>; - } - else { - return nullptr; - } - } -}; + while (remaining_reduction_nelems > + preferrered_reductions_per_wi * max_wg) { + size_t reduction_groups_ = + (remaining_reduction_nelems + + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups_ > 1); -// Argmax and Argmin + // keep reducing + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); -/* = Search reduction using reduce_over_group*/ + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; -template -struct SearchReduction -{ -private: - const argT *inp_ = nullptr; - argT *vals_ = nullptr; - const outT *inds_ = nullptr; - outT *out_ = nullptr; - ReductionOp reduction_op_; - argT identity_; - IdxReductionOp idx_reduction_op_; - outT idx_identity_; - InputOutputIterIndexerT inp_out_iter_indexer_; - InputRedIndexerT inp_reduced_dims_indexer_; - size_t reduction_max_gid_ = 0; - size_t iter_gws_ = 1; - size_t reductions_per_wi = 16; + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + 
static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; -public: - SearchReduction(const argT *data, - argT *vals, - const outT *inds, - outT *res, - ReductionOp reduction_op, - const argT &identity_val, - IdxReductionOp idx_reduction_op, - const outT &idx_identity_val, - InputOutputIterIndexerT arg_res_iter_indexer, - InputRedIndexerT arg_reduced_dims_indexer, - size_t reduction_size, - size_t iteration_size, - size_t reduction_size_per_wi) - : inp_(data), vals_(vals), inds_(inds), out_(res), - reduction_op_(reduction_op), identity_(identity_val), - idx_reduction_op_(idx_reduction_op), idx_identity_(idx_identity_val), - inp_out_iter_indexer_(arg_res_iter_indexer), - inp_reduced_dims_indexer_(arg_reduced_dims_indexer), - reduction_max_gid_(reduction_size), iter_gws_(iteration_size), - reductions_per_wi(reduction_size_per_wi) - { - } + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; - void operator()(sycl::nd_item<1> it) const - { - const size_t reduction_lid = it.get_local_id(0); - const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class middle_reduction_axis0_temps_contig_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class middle_custom_reduction_axis0_temps_contig_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + local_memory, remaining_reduction_nelems, + iter_nelems, preferrered_reductions_per_wi)); + } + }); - const size_t iter_gid = it.get_group(0) % iter_gws_; - const size_t reduction_batch_id = it.get_group(0) / iter_gws_; - const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } - // work-items operates over input with indices - // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg - // + reduction_lid - // for 0 <= m < reductions_per_wi + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); - auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); - const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); - const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + 
InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - argT local_red_val(identity_); - outT local_idx(idx_identity_); - size_t arg_reduce_gid0 = - reduction_lid + reduction_batch_id * wg * reductions_per_wi; - for (size_t m = 0; m < reductions_per_wi; ++m) { - size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{}; - if (arg_reduce_gid < reduction_max_gid_) { - auto inp_reduction_offset = - inp_reduced_dims_indexer_(arg_reduce_gid); - auto inp_offset = inp_iter_offset + inp_reduction_offset; + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; - argT val = inp_[inp_offset]; - if (val == local_red_val) { - if constexpr (!First) { - local_idx = - idx_reduction_op_(local_idx, inds_[inp_offset]); - } - else { - local_idx = idx_reduction_op_( - local_idx, static_cast(arg_reduce_gid)); - } - } - else { - if constexpr (su_ns::IsMinimum::value) { - if (val < local_red_val) { - local_red_val = val; - if constexpr (!First) { - local_idx = inds_[inp_offset]; - } - else { - local_idx = static_cast(arg_reduce_gid); - } - } - } - else if constexpr (su_ns::IsMaximum::value) { - if (val > local_red_val) { - local_red_val = val; - if constexpr (!First) { - local_idx = inds_[inp_offset]; - } - else { - local_idx = static_cast(arg_reduce_gid); - } - } - } - } - } - } + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); - auto work_group = it.get_group(); - // This only works if reduction_op_ is from small set of operators - argT red_val_over_wg = sycl::reduce_over_group( - work_group, local_red_val, identity_, reduction_op_); + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); - if constexpr (std::is_integral_v) { - local_idx = - (red_val_over_wg == local_red_val) ? local_idx : idx_identity_; - } - else { - local_idx = - (red_val_over_wg == local_red_val || - std::isnan(red_val_over_wg) || std::isnan(local_red_val)) - ? 
local_idx - : idx_identity_; - } - outT idx_over_wg = sycl::reduce_over_group( - work_group, local_idx, idx_identity_, idx_reduction_op_); + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; - if (work_group.leader()) { - // each group writes to a different memory location - if constexpr (!Last) { - // if not the final reduction, write value corresponding to - // an index to a temporary - vals_[out_iter_offset * n_reduction_groups + - reduction_batch_id] = red_val_over_wg; + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class final_reduction_axis0_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>(temp_arg, res_tp, ReductionOpT(), + identity_val, in_out_iter_indexer, + reduction_indexer, + remaining_reduction_nelems, + iter_nelems, reductions_per_wi)); } - out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = - idx_over_wg; - } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class final_custom_reduction_axis0_temps_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + } + }); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; } -}; +} -/* = Search reduction using custom_reduce_over_group*/ +/* @brief Types supported by comparison-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForCompReductionAtomic +{ -template -struct CustomSearchReduction + /* value if true a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ + // input int32 + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForCompReductionTemps { -private: - const argT *inp_ = nullptr; - argT *vals_ = nullptr; - const outT *inds_ = nullptr; - outT *out_ = nullptr; - ReductionOp reduction_op_; - argT identity_; - IdxReductionOp idx_reduction_op_; - outT idx_identity_; - InputOutputIterIndexerT inp_out_iter_indexer_; - InputRedIndexerT inp_reduced_dims_indexer_; - SlmT local_mem_; - size_t reduction_max_gid_ = 0; - size_t iter_gws_ = 1; - size_t reductions_per_wi = 16; -public: - CustomSearchReduction(const argT *data, - argT *vals, - outT *inds, - outT *res, - ReductionOp reduction_op, - const argT &identity_val, - IdxReductionOp 
idx_reduction_op, - const outT &idx_identity_val, - InputOutputIterIndexerT arg_res_iter_indexer, - InputRedIndexerT arg_reduced_dims_indexer, - SlmT local_mem, - size_t reduction_size, - size_t iteration_size, - size_t reduction_size_per_wi) - : inp_(data), vals_(vals), inds_(inds), out_(res), - reduction_op_(reduction_op), identity_(identity_val), - idx_reduction_op_(idx_reduction_op), idx_identity_(idx_identity_val), - inp_out_iter_indexer_(arg_res_iter_indexer), - inp_reduced_dims_indexer_(arg_reduced_dims_indexer), - local_mem_(local_mem), reduction_max_gid_(reduction_size), - iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct MaxOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxisTempsStridedFactory +{ + fnT get() const { + if constexpr (TypePairSupportDataForCompReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + } + else { + return nullptr; + } } +}; - void operator()(sycl::nd_item<1> it) const +template +struct MaxOverAxis1AtomicContigFactory +{ + fnT get() const { - const size_t reduction_lid = it.get_local_id(0); - const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + if constexpr (TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; - const size_t iter_gid = it.get_group(0) % iter_gws_; - const size_t reduction_batch_id = it.get_group(0) / 
iter_gws_; - const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; +template +struct MaxOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr 
(TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +// Sum + +/* @brief Types supported by plus-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForSumReductionAtomic +{ + + /* value if true a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForSumReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input double + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-throug + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct SumOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +// Product + +/* @brief Types supported by plus-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForProductReductionAtomic +{ + + /* value if true a kernel for must be 
instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForProductReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input double + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-throug + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct ProductOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxisTempsStridedFactory +{ + fnT get() const 
+ { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct TypePairSupportDataForHypotReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct HypotOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct HypotOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; 
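Aside: a minimal serial sketch of the pairwise combine assumed for the hypot tree reduction registered in the factories above; the helper name reduce_hypot_serial is illustrative and not part of the patch. Folding partial results with std::hypot yields the square root of the sum of squares without forming the sum explicitly, which fits the non-atomic ("temps") variants being the only ones registered for this operation.

#include <cmath>
#include <vector>

// Serial model of the combine step: the running value r and a new element x
// are folded as hypot(r, x) = sqrt(r*r + x*x), starting from the identity 0,
// so the final value equals sqrt of the sum of squares.
double reduce_hypot_serial(const std::vector<double> &v)
{
    double r = 0.0;
    for (double x : v) {
        r = std::hypot(r, x);
    }
    return r;
}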
+ return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct HypotOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct TypePairSupportDataForLogSumExpReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct LogSumExpOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct LogSumExpOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct LogSumExpOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +// Argmax and Argmin + +/* = Search reduction using reduce_over_group*/ + +template +struct SearchReduction +{ +private: + const argT *inp_ = nullptr; + argT *vals_ = nullptr; + const outT *inds_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; + size_t reductions_per_wi = 16; + +public: + SearchReduction(const argT *data, + argT *vals, + const outT *inds, + outT *res, + ReductionOp reduction_op, + const argT &identity_val, + IdxReductionOp idx_reduction_op, + const outT &idx_identity_val, + 
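A brief, hedged sketch of the pairwise combine assumed for the logsumexp tree reduction registered just above; the helper logaddexp_combine is illustrative and is not the patch's math_utils implementation.

#include <algorithm>
#include <cmath>

// Numerically stable log(exp(a) + exp(b)): factor the larger argument out of
// the logarithm so the remaining exponential cannot overflow.
double logaddexp_combine(double a, double b)
{
    if (std::isinf(a) && a < 0.0) return b; // -inf acts as the identity
    if (std::isinf(b) && b < 0.0) return a;
    const double m = std::max(a, b);
    return m + std::log1p(std::exp(-std::fabs(a - b)));
}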
InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : inp_(data), vals_(vals), inds_(inds), out_(res), + reduction_op_(reduction_op), identity_(identity_val), + idx_reduction_op_(idx_reduction_op), idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + const size_t iter_gid = it.get_group(0) % iter_gws_; + const size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + argT local_red_val(identity_); + outT local_idx(idx_identity_); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (size_t m = 0; m < reductions_per_wi; ++m) { + size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == local_red_val) { + if constexpr (!First) { + local_idx = + idx_reduction_op_(local_idx, inds_[inp_offset]); + } + else { + local_idx = idx_reduction_op_( + local_idx, static_cast(arg_reduce_gid)); + } + } + else { + if constexpr (su_ns::IsMinimum::value) { + if (val < local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = static_cast(arg_reduce_gid); + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + if (val > local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = static_cast(arg_reduce_gid); + } + } + } + } + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + argT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, identity_, reduction_op_); + + if constexpr (std::is_integral_v) { + local_idx = + (red_val_over_wg == local_red_val) ? local_idx : idx_identity_; + } + else { + local_idx = + (red_val_over_wg == local_red_val || + std::isnan(red_val_over_wg) || std::isnan(local_red_val)) + ? 
local_idx + : idx_identity_; + } + outT idx_over_wg = sycl::reduce_over_group( + work_group, local_idx, idx_identity_, idx_reduction_op_); + + if (work_group.leader()) { + // each group writes to a different memory location + if constexpr (!Last) { + // if not the final reduction, write value corresponding to + // an index to a temporary + vals_[out_iter_offset * n_reduction_groups + + reduction_batch_id] = red_val_over_wg; + } + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + idx_over_wg; + } + } +}; + +/* = Search reduction using custom_reduce_over_group*/ + +template +struct CustomSearchReduction +{ +private: + const argT *inp_ = nullptr; + argT *vals_ = nullptr; + const outT *inds_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + SlmT local_mem_; + size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; + size_t reductions_per_wi = 16; + +public: + CustomSearchReduction(const argT *data, + argT *vals, + outT *inds, + outT *res, + ReductionOp reduction_op, + const argT &identity_val, + IdxReductionOp idx_reduction_op, + const outT &idx_identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + SlmT local_mem, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : inp_(data), vals_(vals), inds_(inds), out_(res), + reduction_op_(reduction_op), identity_(identity_val), + idx_reduction_op_(idx_reduction_op), idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + const size_t iter_gid = it.get_group(0) % iter_gws_; + const size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + argT local_red_val(identity_); + outT local_idx(idx_identity_); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (size_t m = 0; m < reductions_per_wi; ++m) { + size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == local_red_val) { + if constexpr (!First) { + local_idx = + idx_reduction_op_(local_idx, inds_[inp_offset]); + } + else { + local_idx = idx_reduction_op_( + local_idx, static_cast(arg_reduce_gid)); + } + } + else { + if constexpr (su_ns::IsMinimum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::less_complex; + // less_complex always 
returns false for NaNs, so + // check + if (less_complex(val, local_red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else if constexpr (std::is_floating_point_v) { + if (val < local_red_val || std::isnan(val)) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else { + if (val < local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::greater_complex; + if (greater_complex(val, local_red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else if constexpr (std::is_floating_point_v) { + if (val > local_red_val || std::isnan(val)) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else { + if (val > local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + } + } + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + argT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + // equality does not hold for NaNs, so check here + local_idx = (red_val_over_wg == local_red_val || + std::isnan(std::real(local_red_val)) || + std::isnan(std::imag(local_red_val))) + ? local_idx + : idx_identity_; + } + else if constexpr (std::is_floating_point_v) { + // equality does not hold for NaNs, so check here + local_idx = + (red_val_over_wg == local_red_val || std::isnan(local_red_val)) + ? local_idx + : idx_identity_; + } + else { + local_idx = + red_val_over_wg == local_red_val ? 
local_idx : idx_identity_; + } + outT idx_over_wg = sycl::reduce_over_group( + work_group, local_idx, idx_identity_, idx_reduction_op_); + if (work_group.leader()) { + // each group writes to a different memory location + if constexpr (!Last) { + // if not the final reduction, write value corresponding to + // an index to a temporary + vals_[out_iter_offset * n_reduction_groups + + reduction_batch_id] = red_val_over_wg; + } + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + idx_over_wg; + } + } +}; + +typedef sycl::event (*search_strided_impl_fn_ptr)( + sycl::queue, + size_t, + size_t, + const char *, + char *, + int, + const py::ssize_t *, + py::ssize_t, + py::ssize_t, + int, + const py::ssize_t *, + py::ssize_t, + const std::vector &); + +template +class search_over_group_temps_strided_krn; + +template +class custom_search_over_group_temps_strided_krn; + +template +class single_search_axis0_temps_contig_krn; + +template +class first_search_axis0_temps_contig_krn; + +template +class middle_search_axis0_temps_contig_krn; + +template +class final_search_axis0_temps_contig_krn; + +template +class single_custom_search_axis0_temps_contig_krn; + +template +class first_custom_search_axis0_temps_contig_krn; + +template +class middle_custom_search_axis0_temps_contig_krn; + +template +class final_custom_search_axis0_temps_contig_krn; + +template +class single_search_axis1_temps_contig_krn; + +template +class first_search_axis1_temps_contig_krn; + +template +class middle_search_axis1_temps_contig_krn; + +template +class final_search_axis1_temps_contig_krn; + +template +class single_custom_search_axis1_temps_contig_krn; + +template +class first_custom_search_axis1_temps_contig_krn; + +template +class middle_custom_search_axis1_temps_contig_krn; + +template +class final_custom_search_axis1_temps_contig_krn; + +template +sycl::event search_over_group_temps_strided_impl( + sycl::queue exec_q, + size_t iter_nelems, // number of reductions (num. of rows in a matrix + // when reducing over rows) + size_t reduction_nelems, // size of each reduction (length of rows, i.e. 
+ // number of columns) + const char *arg_cp, + char *res_cp, + int iter_nd, + const py::ssize_t *iter_shape_and_strides, + py::ssize_t iter_arg_offset, + py::ssize_t iter_res_offset, + int red_nd, + const py::ssize_t *reduction_shape_stride, + py::ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + constexpr argTy identity_val = su_ns::Identity::value; + constexpr resTy idx_identity_val = su_ns::Identity::value; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferrered_reductions_per_wi = 4; + // max_max_wg prevents running out of resources on CPU + size_t max_wg = std::min( + size_t(2048), d.get_info()); + + size_t reductions_per_wi(preferrered_reductions_per_wi); + if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { + // reduction only requries 1 work-group, can output directly to res + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, + iter_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class search_over_group_temps_strided_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, true, true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + SearchReduction( + arg_tp, nullptr, nullptr, res_tp, ReductionOpT(), + identity_val, IndexOpT(), idx_identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class custom_search_over_group_temps_strided_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, true, + true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + arg_tp, nullptr, nullptr, res_tp, ReductionOpT(), + identity_val, IndexOpT(), idx_identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + reduction_nelems, iter_nelems, reductions_per_wi)); + } + }); + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups > 1); + + size_t second_iter_reduction_groups_ = + (reduction_groups + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (reduction_groups + second_iter_reduction_groups_), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if 
(partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + } + + argTy *partially_reduced_vals_tmp = sycl::malloc_device( + iter_nelems * (reduction_groups + second_iter_reduction_groups_), + exec_q); + argTy *partially_reduced_vals_tmp2 = nullptr; + + if (partially_reduced_vals_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_vals_tmp2 = + partially_reduced_vals_tmp + reduction_groups * iter_nelems; + } + + sycl::event first_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + // Only 2*iter_nd entries describing shape and strides of iterated + // dimensions of input array from iter_shape_and_strides are going + // to be accessed by inp_indexer + InputIndexerT inp_indexer(iter_nd, iter_arg_offset, + iter_shape_and_strides); + ResIndexerT noop_tmp_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + noop_tmp_indexer}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class search_over_group_temps_strided_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, true, false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + SearchReduction( + arg_tp, partially_reduced_vals_tmp, nullptr, + partially_reduced_tmp, ReductionOpT(), identity_val, + IndexOpT(), idx_identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class custom_search_over_group_temps_strided_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, true, + false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + arg_tp, partially_reduced_vals_tmp, nullptr, + partially_reduced_tmp, ReductionOpT(), identity_val, + IndexOpT(), idx_identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, reduction_nelems, + iter_nelems, preferrered_reductions_per_wi)); + } + }); + + size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + + argTy *vals_temp_arg = partially_reduced_vals_tmp; + argTy *vals_temp2_arg = partially_reduced_vals_tmp2; + + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferrered_reductions_per_wi * max_wg) { + size_t reduction_groups_ = + (remaining_reduction_nelems + + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = + 
dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class search_over_group_temps_strided_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, false, + false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + SearchReduction( + vals_temp_arg, vals_temp2_arg, temp_arg, temp2_arg, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, remaining_reduction_nelems, + iter_nelems, preferrered_reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class custom_search_over_group_temps_strided_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, + false, false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + vals_temp_arg, vals_temp2_arg, temp_arg, temp2_arg, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, + remaining_reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + }); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + std::swap(vals_temp_arg, vals_temp2_arg); + dependent_ev = partial_reduction_ev; + } + + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{iter_nd, iter_res_offset, + /* shape */ iter_shape_and_strides, + /* strides */ iter_shape_and_strides + + 2 * iter_nd}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class search_over_group_temps_strided_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, false, true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + 
SearchReduction( + vals_temp_arg, nullptr, temp_arg, res_tp, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, remaining_reduction_nelems, + iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class custom_search_over_group_temps_strided_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, false, + true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + vals_temp_arg, nullptr, temp_arg, res_tp, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + } + }); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + sycl::context ctx = exec_q.get_context(); + + cgh.host_task( + [ctx, partially_reduced_tmp, partially_reduced_vals_tmp] { + sycl::free(partially_reduced_tmp, ctx); + sycl::free(partially_reduced_vals_tmp, ctx); + }); + }); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +typedef sycl::event (*search_contig_impl_fn_ptr)( + sycl::queue, + size_t, + size_t, + const char *, + char *, + py::ssize_t, + py::ssize_t, + py::ssize_t, + const std::vector &); + +template +sycl::event search_axis1_over_group_temps_contig_impl( + sycl::queue exec_q, + size_t iter_nelems, // number of reductions (num. of rows in a matrix + // when reducing over rows) + size_t reduction_nelems, // size of each reduction (length of rows, i.e. + // number of columns) + const char *arg_cp, + char *res_cp, + py::ssize_t iter_arg_offset, + py::ssize_t iter_res_offset, + py::ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + constexpr argTy identity_val = su_ns::Identity::value; + constexpr resTy idx_identity_val = su_ns::Identity::value; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferrered_reductions_per_wi = 8; + // max_max_wg prevents running out of resources on CPU + size_t max_wg = std::min( + size_t(2048), d.get_info()); + + size_t reductions_per_wi(preferrered_reductions_per_wi); + if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { + // reduction only requries 1 work-group, can output directly to res + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputIterIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = 
+ sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class single_search_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, true, true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + SearchReduction( + arg_tp, nullptr, nullptr, res_tp, ReductionOpT(), + identity_val, IndexOpT(), idx_identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class single_custom_search_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, true, + true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + arg_tp, nullptr, nullptr, res_tp, ReductionOpT(), + identity_val, IndexOpT(), idx_identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + reduction_nelems, iter_nelems, reductions_per_wi)); + } + }); + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups > 1); + + size_t second_iter_reduction_groups_ = + (reduction_groups + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (reduction_groups + second_iter_reduction_groups_), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + } + + argTy *partially_reduced_vals_tmp = sycl::malloc_device( + iter_nelems * (reduction_groups + second_iter_reduction_groups_), + exec_q); + argTy *partially_reduced_vals_tmp2 = nullptr; + + if (partially_reduced_vals_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_vals_tmp2 = + partially_reduced_vals_tmp + reduction_groups * iter_nelems; + } + + sycl::event first_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputIterIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class first_search_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, true, false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + SearchReduction( + arg_tp, partially_reduced_vals_tmp, nullptr, + partially_reduced_tmp, ReductionOpT(), identity_val, + IndexOpT(), idx_identity_val, 
in_out_iter_indexer, + reduction_indexer, reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class first_custom_search_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, true, + false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + arg_tp, partially_reduced_vals_tmp, nullptr, + partially_reduced_tmp, ReductionOpT(), identity_val, + IndexOpT(), idx_identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, reduction_nelems, + iter_nelems, preferrered_reductions_per_wi)); + } + }); + + size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + + argTy *vals_temp_arg = partially_reduced_vals_tmp; + argTy *vals_temp2_arg = partially_reduced_vals_tmp2; + + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferrered_reductions_per_wi * max_wg) { + size_t reduction_groups_ = + (remaining_reduction_nelems + + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups_ > 1); - // work-items operates over input with indices - // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg - // + reduction_lid - // for 0 <= m < reductions_per_wi + // keep reducing + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); - auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); - const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); - const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; - argT local_red_val(identity_); - outT local_idx(idx_identity_); - size_t arg_reduce_gid0 = - reduction_lid + reduction_batch_id * wg * reductions_per_wi; - for (size_t m = 0; m < reductions_per_wi; ++m) { - size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; - if (arg_reduce_gid < reduction_max_gid_) { - auto inp_reduction_offset = - inp_reduced_dims_indexer_(arg_reduce_gid); - auto inp_offset = inp_iter_offset + inp_reduction_offset; + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; - argT val = inp_[inp_offset]; - if (val == local_red_val) { - if constexpr (!First) { - local_idx = - idx_reduction_op_(local_idx, inds_[inp_offset]); - } - else { - local_idx = idx_reduction_op_( - local_idx, static_cast(arg_reduce_gid)); - } + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class middle_search_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, false, + false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + SearchReduction( + vals_temp_arg, 
vals_temp2_arg, temp_arg, temp2_arg, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, remaining_reduction_nelems, + iter_nelems, preferrered_reductions_per_wi)); } else { - if constexpr (su_ns::IsMinimum::value) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (is_complex::value) { - using dpctl::tensor::math_utils::less_complex; - // less_complex always returns false for NaNs, so - // check - if (less_complex(val, local_red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) - { - local_red_val = val; - if constexpr (!First) { - local_idx = inds_[inp_offset]; - } - else { - local_idx = - static_cast(arg_reduce_gid); - } - } - } - else if constexpr (std::is_floating_point_v) { - if (val < local_red_val || std::isnan(val)) { - local_red_val = val; - if constexpr (!First) { - local_idx = inds_[inp_offset]; - } - else { - local_idx = - static_cast(arg_reduce_gid); - } - } - } - else { - if (val < local_red_val) { - local_red_val = val; - if constexpr (!First) { - local_idx = inds_[inp_offset]; - } - else { - local_idx = - static_cast(arg_reduce_gid); - } - } - } - } - else if constexpr (su_ns::IsMaximum::value) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (is_complex::value) { - using dpctl::tensor::math_utils::greater_complex; - if (greater_complex(val, local_red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) - { - local_red_val = val; - if constexpr (!First) { - local_idx = inds_[inp_offset]; - } - else { - local_idx = - static_cast(arg_reduce_gid); - } - } - } - else if constexpr (std::is_floating_point_v) { - if (val > local_red_val || std::isnan(val)) { - local_red_val = val; - if constexpr (!First) { - local_idx = inds_[inp_offset]; - } - else { - local_idx = - static_cast(arg_reduce_gid); - } - } - } - else { - if (val > local_red_val) { - local_red_val = val; - if constexpr (!First) { - local_idx = inds_[inp_offset]; - } - else { - local_idx = - static_cast(arg_reduce_gid); - } - } - } - } + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class middle_custom_search_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, + false, false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + vals_temp_arg, vals_temp2_arg, temp_arg, temp2_arg, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, + remaining_reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); } - } + }); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + std::swap(vals_temp_arg, vals_temp2_arg); + dependent_ev = partial_reduction_ev; } - auto work_group = it.get_group(); - // This only works if reduction_op_ is from small set of operators - argT red_val_over_wg = su_ns::custom_reduce_over_group( - work_group, local_mem_, local_red_val, reduction_op_); + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); - using dpctl::tensor::type_utils::is_complex; - if constexpr (is_complex::value) { - // equality does not hold for NaNs, so check here - local_idx = (red_val_over_wg == local_red_val || - std::isnan(std::real(local_red_val)) || - std::isnan(std::imag(local_red_val))) - ? 
local_idx - : idx_identity_; - } - else if constexpr (std::is_floating_point_v) { - // equality does not hold for NaNs, so check here - local_idx = - (red_val_over_wg == local_red_val || std::isnan(local_red_val)) - ? local_idx - : idx_identity_; - } - else { - local_idx = - red_val_over_wg == local_red_val ? local_idx : idx_identity_; - } - outT idx_over_wg = sycl::reduce_over_group( - work_group, local_idx, idx_identity_, idx_reduction_op_); - if (work_group.leader()) { - // each group writes to a different memory location - if constexpr (!Last) { - // if not the final reduction, write value corresponding to - // an index to a temporary - vals_[out_iter_offset * n_reduction_groups + - reduction_batch_id] = red_val_over_wg; + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class final_search_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, false, true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + SearchReduction( + vals_temp_arg, nullptr, temp_arg, res_tp, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, remaining_reduction_nelems, + iter_nelems, reductions_per_wi)); } - out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = - idx_over_wg; - } - } -}; + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class final_custom_search_axis1_temps_contig_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, false, + true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + vals_temp_arg, nullptr, temp_arg, res_tp, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + } + }); -typedef sycl::event (*search_reduction_strided_impl_fn_ptr)( - sycl::queue, - size_t, - size_t, - const char *, - char *, - int, - const py::ssize_t *, - py::ssize_t, - py::ssize_t, - int, - const py::ssize_t *, - py::ssize_t, - const std::vector &); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + sycl::context ctx = exec_q.get_context(); -template -class search_reduction_over_group_temps_krn; + cgh.host_task( + [ctx, partially_reduced_tmp, partially_reduced_vals_tmp] { + sycl::free(partially_reduced_tmp, ctx); + sycl::free(partially_reduced_vals_tmp, ctx); 
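
For readers following the temps-based search reductions above: the scheme first reduces each row to a set of partial (value, index) pairs, then keeps folding those partials while ping-ponging between two scratch buffers (the std::swap calls on temp_arg/vals_temp_arg) until a single pair per row remains, after which the temporaries are freed in a host task. Below is a minimal host-side C++ sketch of that idea for argmax, not the SYCL kernels themselves; ValIdx, reduce_pass, and group_size are illustrative names, with group_size standing in for wg * reductions_per_wi.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

using ValIdx = std::pair<float, std::size_t>;

// One reduction pass: fold every `group_size` consecutive pairs into one.
static std::vector<ValIdx> reduce_pass(const std::vector<ValIdx> &in,
                                       std::size_t group_size)
{
    std::vector<ValIdx> out;
    for (std::size_t start = 0; start < in.size(); start += group_size) {
        ValIdx acc = in[start];
        const std::size_t stop = std::min(in.size(), start + group_size);
        for (std::size_t i = start + 1; i < stop; ++i) {
            // argmax rule: prefer the larger value, the smaller index on ties
            if (in[i].first > acc.first ||
                (in[i].first == acc.first && in[i].second < acc.second))
                acc = in[i];
        }
        out.push_back(acc);
    }
    return out;
}

int main()
{
    const std::vector<float> row{3.f, 9.f, 1.f, 9.f, 4.f, 7.f};
    std::vector<ValIdx> partial;
    for (std::size_t i = 0; i < row.size(); ++i)
        partial.emplace_back(row[i], i);

    const std::size_t group_size = 2; // stands in for wg * reductions_per_wi
    while (partial.size() > 1)
        partial = reduce_pass(partial, group_size); // buffer-swap analogue

    std::cout << "argmax index: " << partial[0].second << "\n"; // prints 1
    return 0;
}

Ties keep the smaller index, matching the sycl::minimum index operation used by the argmax/argmin factories further below.
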
+ }); + }); -template -class search_custom_reduction_over_group_temps_krn; + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list -using dpctl::tensor::sycl_utils::choose_workgroup_size; + return cleanup_host_task_event; + } +} template -sycl::event search_reduction_over_group_temps_strided_impl( +sycl::event search_axis0_over_group_temps_contig_impl( sycl::queue exec_q, size_t iter_nelems, // number of reductions (num. of rows in a matrix // when reducing over rows) @@ -2525,12 +4630,8 @@ sycl::event search_reduction_over_group_temps_strided_impl( // number of columns) const char *arg_cp, char *res_cp, - int iter_nd, - const py::ssize_t *iter_shape_and_strides, py::ssize_t iter_arg_offset, py::ssize_t iter_res_offset, - int red_nd, - const py::ssize_t *reduction_shape_stride, py::ssize_t reduction_arg_offset, const std::vector &depends) { @@ -2544,7 +4645,7 @@ sycl::event search_reduction_over_group_temps_strided_impl( const auto &sg_sizes = d.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - constexpr size_t preferrered_reductions_per_wi = 4; + constexpr size_t preferrered_reductions_per_wi = 8; // max_max_wg prevents running out of resources on CPU size_t max_wg = std::min( size_t(2048), d.get_info()); @@ -2555,16 +4656,20 @@ sycl::event search_reduction_over_group_temps_strided_impl( sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; - InputOutputIterIndexerT in_out_iter_indexer{ - iter_nd, iter_arg_offset, iter_res_offset, - iter_shape_and_strides}; - ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, - reduction_shape_stride}; + NoOpIndexerT columns_indexer{}; + NoOpIndexerT result_indexer{}; + InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; wg = max_wg; reductions_per_wi = @@ -2581,7 +4686,7 @@ sycl::event search_reduction_over_group_temps_strided_impl( if constexpr (can_use_reduce_over_group::value) { - using KernelName = class search_reduction_over_group_temps_krn< + using KernelName = class single_search_axis0_temps_contig_krn< argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, ReductionIndexerT, true, true>; cgh.parallel_for( @@ -2598,7 +4703,7 @@ sycl::event search_reduction_over_group_temps_strided_impl( using SlmT = sycl::local_accessor; SlmT local_memory = SlmT(localRange, cgh); using KernelName = - class search_custom_reduction_over_group_temps_krn< + class single_custom_search_axis0_temps_contig_krn< argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, ReductionIndexerT, SlmT, true, true>; @@ -2655,25 +4760,20 @@ sycl::event search_reduction_over_group_temps_strided_impl( sycl::event first_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); - using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using NoOpIndexerT = 
dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; using InputOutputIterIndexerT = dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - InputIndexerT, ResIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - - // Only 2*iter_nd entries describing shape and strides of iterated - // dimensions of input array from iter_shape_and_strides are going - // to be accessed by inp_indexer - InputIndexerT inp_indexer(iter_nd, iter_arg_offset, - iter_shape_and_strides); - ResIndexerT noop_tmp_indexer{}; + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; - InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, - noop_tmp_indexer}; - ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, - reduction_shape_stride}; + NoOpIndexerT columns_indexer{}; + NoOpIndexerT result_indexer{}; + InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; @@ -2681,7 +4781,7 @@ sycl::event search_reduction_over_group_temps_strided_impl( if constexpr (can_use_reduce_over_group::value) { - using KernelName = class search_reduction_over_group_temps_krn< + using KernelName = class first_search_axis0_temps_contig_krn< argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, ReductionIndexerT, true, false>; cgh.parallel_for( @@ -2699,7 +4799,7 @@ sycl::event search_reduction_over_group_temps_strided_impl( using SlmT = sycl::local_accessor; SlmT local_memory = SlmT(localRange, cgh); using KernelName = - class search_custom_reduction_over_group_temps_krn< + class first_custom_search_axis0_temps_contig_krn< argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, ReductionIndexerT, SlmT, true, false>; @@ -2763,7 +4863,7 @@ sycl::event search_reduction_over_group_temps_strided_impl( if constexpr (can_use_reduce_over_group::value) { using KernelName = - class search_reduction_over_group_temps_krn< + class middle_search_axis0_temps_contig_krn< argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, ReductionIndexerT, false, false>; @@ -2782,7 +4882,7 @@ sycl::event search_reduction_over_group_temps_strided_impl( using SlmT = sycl::local_accessor; SlmT local_memory = SlmT(localRange, cgh); using KernelName = - class search_custom_reduction_over_group_temps_krn< + class middle_custom_search_axis0_temps_contig_krn< argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, ReductionIndexerT, SlmT, false, false>; @@ -2812,8 +4912,7 @@ sycl::event search_reduction_over_group_temps_strided_impl( cgh.depends_on(dependent_ev); using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; - using ResIndexerT = - dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; using InputOutputIterIndexerT = dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< InputIndexerT, ResIndexerT>; @@ -2822,10 +4921,7 @@ sycl::event search_reduction_over_group_temps_strided_impl( InputIndexerT inp_indexer{ 0, static_cast(iter_nelems), static_cast(remaining_reduction_nelems)}; - ResIndexerT res_iter_indexer{iter_nd, iter_res_offset, - /* shape */ iter_shape_and_strides, - /*s trides */ iter_shape_and_strides + - 2 * iter_nd}; + ResIndexerT res_iter_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, 
res_iter_indexer}; @@ -2846,7 +4942,7 @@ sycl::event search_reduction_over_group_temps_strided_impl( if constexpr (can_use_reduce_over_group::value) { - using KernelName = class search_reduction_over_group_temps_krn< + using KernelName = class final_search_axis0_temps_contig_krn< argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, ReductionIndexerT, false, true>; cgh.parallel_for( @@ -2864,7 +4960,7 @@ sycl::event search_reduction_over_group_temps_strided_impl( using SlmT = sycl::local_accessor; SlmT local_memory = SlmT(localRange, cgh); using KernelName = - class search_custom_reduction_over_group_temps_krn< + class final_custom_search_axis0_temps_contig_krn< argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, ReductionIndexerT, SlmT, false, true>; @@ -2971,7 +5067,75 @@ struct ArgmaxOverAxisTempsStridedFactory // op for indices using IndexOpT = sycl::minimum; return dpctl::tensor::kernels:: - search_reduction_over_group_temps_strided_impl< + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgmaxOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSearchReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgmaxOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSearchReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< srcTy, dstTy, ReductionOpT, IndexOpT>; } else { @@ -2980,7 +5144,7 @@ struct ArgmaxOverAxisTempsStridedFactory // op for indices using IndexOpT = sycl::minimum; return dpctl::tensor::kernels:: - search_reduction_over_group_temps_strided_impl< + search_axis0_over_group_temps_contig_impl< srcTy, dstTy, ReductionOpT, IndexOpT>; } } @@ -3005,7 +5169,75 @@ struct ArgminOverAxisTempsStridedFactory // op for indices using IndexOpT = sycl::minimum; return dpctl::tensor::kernels:: - search_reduction_over_group_temps_strided_impl< + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgminOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSearchReductionTemps< 
+ srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgminOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSearchReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< srcTy, dstTy, ReductionOpT, IndexOpT>; } else { @@ -3014,7 +5246,7 @@ struct ArgminOverAxisTempsStridedFactory // op for indices using IndexOpT = sycl::minimum; return dpctl::tensor::kernels:: - search_reduction_over_group_temps_strided_impl< + search_axis0_over_group_temps_contig_impl< srcTy, dstTy, ReductionOpT, IndexOpT>; } } diff --git a/dpctl/tensor/libtensor/include/utils/math_utils.hpp b/dpctl/tensor/libtensor/include/utils/math_utils.hpp index d724e03e35..120a14d536 100644 --- a/dpctl/tensor/libtensor/include/utils/math_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/math_utils.hpp @@ -115,6 +115,26 @@ template T min_complex(const T &x1, const T &x2) return (std::isnan(real1) || isnan_imag1 || lt) ? x1 : x2; } +template T logaddexp(T x, T y) +{ + if (x == y) { // handle signed infinities + const T log2 = std::log(T(2)); + return x + log2; + } + else { + const T tmp = x - y; + if (tmp > 0) { + return x + std::log1p(std::exp(-tmp)); + } + else if (tmp <= 0) { + return y + std::log1p(std::exp(tmp)); + } + else { + return std::numeric_limits::quiet_NaN(); + } + } +} + } // namespace math_utils } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp index 0d4240c516..c0165b0ecc 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -286,6 +286,46 @@ struct GetIdentity::value>> static constexpr T value = static_cast(1); }; +// LogSumExp + +template struct LogSumExp +{ + T operator()(const T &x, const T &y) const + { + using dpctl::tensor::math_utils::logaddexp; + return logaddexp(x, y); + } +}; + +template +using IsLogSumExp = std::bool_constant>>; + +// only defined for types with infinity +template +struct GetIdentity::value>> +{ + static constexpr T value = -std::numeric_limits::infinity(); +}; + +// Hypot + +template struct Hypot +{ + T operator()(const T &x, const T &y) const + { + return sycl::hypot(x, y); + } +}; + +template +using IsHypot = std::bool_constant>>; + +template +struct GetIdentity::value>> +{ + static constexpr T value = 0; +}; + // Identity template struct Identity diff --git a/dpctl/tensor/libtensor/source/reduction_over_axis.cpp b/dpctl/tensor/libtensor/source/reduction_over_axis.cpp deleted file mode 100644 index c67fcd5ba3..0000000000 --- a/dpctl/tensor/libtensor/source/reduction_over_axis.cpp +++ /dev/null @@ -1,514 +0,0 @@ -//===-- ------------ Implementation of 
_tensor_impl module ----*-C++-*-/===// -// -// Data Parallel Control (dpctl) -// -// Copyright 2020-2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===--------------------------------------------------------------------===// -/// -/// \file -/// This file defines functions of dpctl.tensor._tensor_impl extensions -//===--------------------------------------------------------------------===// - -#include -#include -#include -#include - -#include -#include -#include - -#include "dpctl4pybind11.hpp" -#include "kernels/reductions.hpp" -#include "reduction_over_axis.hpp" -#include "simplify_iteration_space.hpp" -#include "utils/type_dispatch.hpp" - -namespace dpctl -{ -namespace tensor -{ -namespace py_internal -{ - -namespace td_ns = dpctl::tensor::type_dispatch; -// Max -namespace impl -{ - -using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; -static reduction_strided_impl_fn_ptr - max_over_axis_strided_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_strided_impl_fn_ptr - max_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; -static reduction_contig_impl_fn_ptr - max_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_contig_impl_fn_ptr - max_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_max_over_axis_dispatch_tables(void) -{ - using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; - using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; - using td_ns::DispatchTableBuilder; - - using dpctl::tensor::kernels::MaxOverAxisAtomicStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(max_over_axis_strided_atomic_dispatch_table); - - using dpctl::tensor::kernels::MaxOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(max_over_axis_strided_temps_dispatch_table); - - using dpctl::tensor::kernels::MaxOverAxis1AtomicContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(max_over_axis1_contig_atomic_dispatch_table); - - using dpctl::tensor::kernels::MaxOverAxis0AtomicContigFactory; - DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table(max_over_axis0_contig_atomic_dispatch_table); -} - -} // namespace impl - -// Min -namespace impl -{ - -using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; -static reduction_strided_impl_fn_ptr - min_over_axis_strided_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_strided_impl_fn_ptr - min_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; -static reduction_contig_impl_fn_ptr - min_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_contig_impl_fn_ptr - min_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; 
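
The LogSumExp and Hypot operations added to sycl_utils.hpp above combine elements with math_utils::logaddexp and sycl::hypot, starting from identities of -infinity and 0 respectively, which is what lets logsumexp and reduce_hypot run as tree reductions without atomics. The following is a minimal standalone sketch of that folding behaviour, written against <cmath> rather than SYCL so it can be compiled and checked on the host; the variable names and the naive reference computations are illustrative only.

#include <cmath>
#include <iostream>
#include <limits>
#include <vector>

// Same formulation as math_utils::logaddexp above, spelled with <cmath>
// so the sketch builds without a SYCL toolchain.
static double logaddexp(double x, double y)
{
    if (x == y) { // handles equal values, including paired infinities
        return x + std::log(2.0);
    }
    const double tmp = x - y;
    if (tmp > 0)
        return x + std::log1p(std::exp(-tmp));
    if (tmp <= 0)
        return y + std::log1p(std::exp(tmp));
    return std::numeric_limits<double>::quiet_NaN(); // tmp is NaN
}

int main()
{
    const std::vector<double> x{0.5, -1.0, 2.0, 3.5};

    // logsumexp: fold with logaddexp starting from its identity, -inf
    double lse = -std::numeric_limits<double>::infinity();
    for (double v : x)
        lse = logaddexp(lse, v);

    // reduce_hypot: fold with hypot starting from its identity, 0
    double h = 0.0;
    for (double v : x)
        h = std::hypot(h, v);

    // naive references for comparison on this small, well-scaled input
    double sum_exp = 0.0, sum_sq = 0.0;
    for (double v : x) {
        sum_exp += std::exp(v);
        sum_sq += v * v;
    }
    std::cout << lse << " ~= " << std::log(sum_exp) << "\n";
    std::cout << h << " ~= " << std::sqrt(sum_sq) << "\n";
    return 0;
}

On small inputs the folded results agree with the naive log(sum(exp(x))) and sqrt(sum(x*x)); the folded forms additionally avoid the overflow that the naive forms hit for large magnitudes.
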
- -void populate_min_over_axis_dispatch_tables(void) -{ - using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; - using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; - using td_ns::DispatchTableBuilder; - - using dpctl::tensor::kernels::MinOverAxisAtomicStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(min_over_axis_strided_atomic_dispatch_table); - - using dpctl::tensor::kernels::MinOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(min_over_axis_strided_temps_dispatch_table); - - using dpctl::tensor::kernels::MinOverAxis1AtomicContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(min_over_axis1_contig_atomic_dispatch_table); - - using dpctl::tensor::kernels::MinOverAxis0AtomicContigFactory; - DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table(min_over_axis0_contig_atomic_dispatch_table); -} - -} // namespace impl - -// Sum -namespace impl -{ - -using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; -static reduction_strided_impl_fn_ptr - sum_over_axis_strided_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_strided_impl_fn_ptr - sum_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; -static reduction_contig_impl_fn_ptr - sum_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_contig_impl_fn_ptr - sum_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_sum_over_axis_dispatch_tables(void) -{ - using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; - using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; - using namespace td_ns; - - using dpctl::tensor::kernels::SumOverAxisAtomicStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(sum_over_axis_strided_atomic_dispatch_table); - - using dpctl::tensor::kernels::SumOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table); - - using dpctl::tensor::kernels::SumOverAxis1AtomicContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(sum_over_axis1_contig_atomic_dispatch_table); - - using dpctl::tensor::kernels::SumOverAxis0AtomicContigFactory; - DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table(sum_over_axis0_contig_atomic_dispatch_table); -} - -} // namespace impl - -// Product -namespace impl -{ - -using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; -static reduction_strided_impl_fn_ptr - prod_over_axis_strided_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_strided_impl_fn_ptr - prod_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; -static reduction_contig_impl_fn_ptr - prod_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_contig_impl_fn_ptr - prod_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_prod_over_axis_dispatch_tables(void) -{ - using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; - using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; - using namespace td_ns; - - using dpctl::tensor::kernels::ProductOverAxisAtomicStridedFactory; - DispatchTableBuilder - dtb1; - 
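
The populate_* functions removed here, like their replacements in the new per-reduction sources, all follow the same pattern: a per-type-pair factory returns either a typed implementation or nullptr, and DispatchTableBuilder fills a num_types x num_types table of function pointers that the Python entry points index with runtime type ids. A much-reduced sketch of that pattern follows, using illustrative names (fn_ptr_t, my_num_types, MiniFactory, typed_kernel) rather than dpctl's.

#include <cstdint>
#include <cstdio>
#include <type_traits>

using fn_ptr_t = void (*)();

constexpr int my_num_types = 2; // e.g. 0 -> float, 1 -> std::int64_t

template <typename srcT, typename dstT> void typed_kernel()
{
    std::printf("kernel: src %zu bytes -> dst %zu bytes\n", sizeof(srcT),
                sizeof(dstT));
}

template <typename fnT, typename srcT, typename dstT> struct MiniFactory
{
    fnT get() const
    {
        // unsupported type pairs yield nullptr, mirroring the factories
        // above; here only int64 outputs are "supported", for illustration
        if constexpr (std::is_same_v<dstT, std::int64_t>) {
            return typed_kernel<srcT, dstT>;
        }
        else {
            return nullptr;
        }
    }
};

int main()
{
    fn_ptr_t table[my_num_types][my_num_types];
    // stand-in for DispatchTableBuilder::populate_dispatch_table
    table[0][0] = MiniFactory<fn_ptr_t, float, float>{}.get();
    table[0][1] = MiniFactory<fn_ptr_t, float, std::int64_t>{}.get();
    table[1][0] = MiniFactory<fn_ptr_t, std::int64_t, float>{}.get();
    table[1][1] = MiniFactory<fn_ptr_t, std::int64_t, std::int64_t>{}.get();

    const int src_typeid = 0, dst_typeid = 1; // looked up from dtypes at runtime
    if (fn_ptr_t fn = table[src_typeid][dst_typeid]) {
        fn();
    }
    return 0;
}
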
dtb1.populate_dispatch_table(prod_over_axis_strided_atomic_dispatch_table); - - using dpctl::tensor::kernels::ProductOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(prod_over_axis_strided_temps_dispatch_table); - - using dpctl::tensor::kernels::ProductOverAxis1AtomicContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(prod_over_axis1_contig_atomic_dispatch_table); - - using dpctl::tensor::kernels::ProductOverAxis0AtomicContigFactory; - DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table(prod_over_axis0_contig_atomic_dispatch_table); -} - -} // namespace impl - -// Argmax -namespace impl -{ - -using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; -static search_reduction_strided_impl_fn_ptr - argmax_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_argmax_over_axis_dispatch_tables(void) -{ - using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; - using td_ns::DispatchTableBuilder; - - using dpctl::tensor::kernels::ArgmaxOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(argmax_over_axis_strided_temps_dispatch_table); -} - -} // namespace impl - -// Argmin -namespace impl -{ - -using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; -static search_reduction_strided_impl_fn_ptr - argmin_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_argmin_over_axis_dispatch_tables(void) -{ - using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; - using td_ns::DispatchTableBuilder; - - using dpctl::tensor::kernels::ArgminOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(argmin_over_axis_strided_temps_dispatch_table); -} - -} // namespace impl - -namespace py = pybind11; - -void init_reduction_functions(py::module_ m) -{ - using arrayT = dpctl::tensor::usm_ndarray; - using event_vecT = std::vector; - - namespace impl = dpctl::tensor::py_internal::impl; - - using dpctl::tensor::py_internal::py_reduction_dtype_supported; - using dpctl::tensor::py_internal::py_reduction_over_axis; - - using dpctl::tensor::py_internal::check_atomic_support; - using dpctl::tensor::py_internal::fixed_decision; - - // MAX - { - using dpctl::tensor::py_internal::impl:: - populate_max_over_axis_dispatch_tables; - populate_max_over_axis_dispatch_tables(); - using impl::max_over_axis0_contig_atomic_dispatch_table; - using impl::max_over_axis1_contig_atomic_dispatch_table; - using impl::max_over_axis_strided_atomic_dispatch_table; - using impl::max_over_axis_strided_temps_dispatch_table; - - const auto &check_atomic_support_size4 = - check_atomic_support; - const auto &check_atomic_support_size8 = - check_atomic_support; - - auto max_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_reduction_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - max_over_axis_strided_atomic_dispatch_table, - max_over_axis_strided_temps_dispatch_table, - max_over_axis0_contig_atomic_dispatch_table, - max_over_axis1_contig_atomic_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_max_over_axis", max_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - } - - // MIN - { - using dpctl::tensor::py_internal::impl:: - 
populate_min_over_axis_dispatch_tables; - populate_min_over_axis_dispatch_tables(); - using impl::min_over_axis0_contig_atomic_dispatch_table; - using impl::min_over_axis1_contig_atomic_dispatch_table; - using impl::min_over_axis_strided_atomic_dispatch_table; - using impl::min_over_axis_strided_temps_dispatch_table; - - const auto &check_atomic_support_size4 = - check_atomic_support; - const auto &check_atomic_support_size8 = - check_atomic_support; - - auto min_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_reduction_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - min_over_axis_strided_atomic_dispatch_table, - min_over_axis_strided_temps_dispatch_table, - min_over_axis0_contig_atomic_dispatch_table, - min_over_axis1_contig_atomic_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_min_over_axis", min_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - } - - // SUM - { - using dpctl::tensor::py_internal::impl:: - populate_sum_over_axis_dispatch_tables; - populate_sum_over_axis_dispatch_tables(); - using impl::sum_over_axis0_contig_atomic_dispatch_table; - using impl::sum_over_axis1_contig_atomic_dispatch_table; - using impl::sum_over_axis_strided_atomic_dispatch_table; - using impl::sum_over_axis_strided_temps_dispatch_table; - - const auto &check_atomic_support_size4 = - check_atomic_support; - const auto &check_atomic_support_size8 = - check_atomic_support; - - auto sum_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_reduction_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - sum_over_axis_strided_atomic_dispatch_table, - sum_over_axis_strided_temps_dispatch_table, - sum_over_axis0_contig_atomic_dispatch_table, - sum_over_axis1_contig_atomic_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_sum_over_axis", sum_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sum_dtype_supported = - [&](const py::dtype &input_dtype, const py::dtype &output_dtype, - const std::string &dst_usm_type, sycl::queue &q) { - return py_reduction_dtype_supported( - input_dtype, output_dtype, dst_usm_type, q, - sum_over_axis_strided_atomic_dispatch_table, - sum_over_axis_strided_temps_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_sum_over_axis_dtype_supported", sum_dtype_supported, "", - py::arg("arg_dtype"), py::arg("out_dtype"), - py::arg("dst_usm_type"), py::arg("sycl_queue")); - } - - // PROD - { - using dpctl::tensor::py_internal::impl:: - populate_prod_over_axis_dispatch_tables; - populate_prod_over_axis_dispatch_tables(); - using impl::prod_over_axis0_contig_atomic_dispatch_table; - using impl::prod_over_axis1_contig_atomic_dispatch_table; - using impl::prod_over_axis_strided_atomic_dispatch_table; - using impl::prod_over_axis_strided_temps_dispatch_table; - - const auto &check_atomic_support_size4 = - check_atomic_support; - const auto &check_atomic_support_size8 = - check_atomic_support; - - auto prod_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return 
py_reduction_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - prod_over_axis_strided_atomic_dispatch_table, - prod_over_axis_strided_temps_dispatch_table, - prod_over_axis0_contig_atomic_dispatch_table, - prod_over_axis1_contig_atomic_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_prod_over_axis", prod_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto prod_dtype_supported = - [&](const py::dtype &input_dtype, const py::dtype &output_dtype, - const std::string &dst_usm_type, sycl::queue &q) { - return py_reduction_dtype_supported( - input_dtype, output_dtype, dst_usm_type, q, - prod_over_axis_strided_atomic_dispatch_table, - prod_over_axis_strided_temps_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_prod_over_axis_dtype_supported", prod_dtype_supported, "", - py::arg("arg_dtype"), py::arg("out_dtype"), - py::arg("dst_usm_type"), py::arg("sycl_queue")); - } - - // ARGMAX - { - using dpctl::tensor::py_internal::impl:: - populate_argmax_over_axis_dispatch_tables; - populate_argmax_over_axis_dispatch_tables(); - using impl::argmax_over_axis_strided_temps_dispatch_table; - - auto argmax_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - using dpctl::tensor::py_internal::py_search_over_axis; - return py_search_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - argmax_over_axis_strided_temps_dispatch_table); - }; - m.def("_argmax_over_axis", argmax_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - } - - // ARGMIN - { - using dpctl::tensor::py_internal::impl:: - populate_argmin_over_axis_dispatch_tables; - populate_argmin_over_axis_dispatch_tables(); - using impl::argmin_over_axis_strided_temps_dispatch_table; - - auto argmin_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - using dpctl::tensor::py_internal::py_search_over_axis; - return py_search_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - argmin_over_axis_strided_temps_dispatch_table); - }; - m.def("_argmin_over_axis", argmin_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - } -} - -} // namespace py_internal -} // namespace tensor -} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reduction_over_axis.hpp b/dpctl/tensor/libtensor/source/reduction_over_axis.hpp deleted file mode 100644 index 1a9cb6f5e7..0000000000 --- a/dpctl/tensor/libtensor/source/reduction_over_axis.hpp +++ /dev/null @@ -1,689 +0,0 @@ -//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===// -// -// Data Parallel Control (dpctl) -// -// Copyright 2020-2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines functions of dpctl.tensor._tensor_impl extensions, -/// specifically functions for reductions. -//===----------------------------------------------------------------------===// - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "dpctl4pybind11.hpp" -#include -#include -#include - -#include "kernels/reductions.hpp" -#include "simplify_iteration_space.hpp" -#include "utils/memory_overlap.hpp" -#include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" - -namespace dpctl -{ -namespace tensor -{ -namespace py_internal -{ - -template -bool check_atomic_support(const sycl::queue &exec_q, - sycl::usm::alloc usm_alloc_type) -{ - bool supports_atomics = false; - - const sycl::device &dev = exec_q.get_device(); - - if constexpr (require_atomic64) { - if (!dev.has(sycl::aspect::atomic64)) - return false; - } - - switch (usm_alloc_type) { - case sycl::usm::alloc::shared: - supports_atomics = dev.has(sycl::aspect::usm_atomic_shared_allocations); - break; - case sycl::usm::alloc::host: - supports_atomics = dev.has(sycl::aspect::usm_atomic_host_allocations); - break; - case sycl::usm::alloc::device: - supports_atomics = true; - break; - default: - supports_atomics = false; - } - - return supports_atomics; -} - -template -bool fixed_decision(const sycl::queue &, sycl::usm::alloc) -{ - return return_value; -} - -/* ====================== dtype supported ======================== */ - -template -bool py_reduction_dtype_supported( - const py::dtype &input_dtype, - const py::dtype &output_dtype, - const std::string &dst_usm_type, - sycl::queue &q, - const fnT &atomic_dispatch_table, - const fnT &temps_dispatch_table, - const CheckAtomicSupportFnT &check_atomic_support_size4, - const CheckAtomicSupportFnT &check_atomic_support_size8) -{ - int arg_tn = - input_dtype.num(); // NumPy type numbers are the same as in dpctl - int out_tn = - output_dtype.num(); // NumPy type numbers are the same as in dpctl - int arg_typeid = -1; - int out_typeid = -1; - - auto array_types = td_ns::usm_ndarray_types(); - - try { - arg_typeid = array_types.typenum_to_lookup_id(arg_tn); - out_typeid = array_types.typenum_to_lookup_id(out_tn); - } catch (const std::exception &e) { - throw py::value_error(e.what()); - } - - if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 || - out_typeid >= td_ns::num_types) - { - throw std::runtime_error("Reduction type support check: lookup failed"); - } - - // remove_all_extents gets underlying type of table - using fn_ptrT = typename std::remove_all_extents::type; - fn_ptrT fn = nullptr; - - sycl::usm::alloc kind = sycl::usm::alloc::unknown; - - if (dst_usm_type == "device") { - kind = sycl::usm::alloc::device; - } - else if (dst_usm_type == "shared") { - kind = sycl::usm::alloc::shared; - } - else if (dst_usm_type == "host") { - kind = sycl::usm::alloc::host; - } - else { - throw py::value_error("Unrecognized `dst_usm_type` argument."); - } - - bool supports_atomics = false; - - switch (output_dtype.itemsize()) { - case sizeof(float): - { - supports_atomics = check_atomic_support_size4(q, kind); - } break; - case sizeof(double): - { - supports_atomics = check_atomic_support_size8(q, kind); - } break; - } - - if (supports_atomics) { - fn = atomic_dispatch_table[arg_typeid][out_typeid]; - } - - if (fn == 
nullptr) { - // use slower reduction implementation using temporaries - fn = temps_dispatch_table[arg_typeid][out_typeid]; - } - - return (fn != nullptr); -} - -/* ==================== Generic reductions ====================== */ - -template -std::pair py_reduction_over_axis( - const dpctl::tensor::usm_ndarray &src, - int trailing_dims_to_reduce, // comp over this many trailing indexes - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends, - const strided_fnT &atomic_dispatch_table, - const strided_fnT &temps_dispatch_table, - const contig_fnT &axis0_dispatch_table, - const contig_fnT &axis1_dispatch_table, - const SupportAtomicFnT &check_atomic_support_size4, - const SupportAtomicFnT &check_atomic_support_size8) -{ - int src_nd = src.get_ndim(); - int iteration_nd = src_nd - trailing_dims_to_reduce; - if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { - throw py::value_error("Trailing_dim_to_reduce must be positive, but no " - "greater than rank of the array being reduced"); - } - - int dst_nd = dst.get_ndim(); - if (dst_nd != iteration_nd) { - throw py::value_error("Destination array rank does not match input " - "array rank and number of reduced dimensions"); - } - - const py::ssize_t *src_shape_ptr = src.get_shape_raw(); - const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); - - bool same_shapes = true; - for (int i = 0; same_shapes && (i < dst_nd); ++i) { - same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); - } - - if (!same_shapes) { - throw py::value_error("Destination shape does not match unreduced " - "dimensions of the input shape"); - } - - if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { - throw py::value_error( - "Execution queue is not compatible with allocation queues"); - } - - size_t dst_nelems = dst.get_size(); - - size_t reduction_nelems(1); - for (int i = dst_nd; i < src_nd; ++i) { - reduction_nelems *= static_cast(src_shape_ptr[i]); - } - - // check that dst and src do not overlap - auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); - if (overlap(src, dst)) { - throw py::value_error("Arrays index overlapping segments of memory"); - } - - // destination must be ample enough to accommodate all elements - { - auto dst_offsets = dst.get_minmax_offsets(); - size_t range = - static_cast(dst_offsets.second - dst_offsets.first); - if (range + 1 < dst_nelems) { - throw py::value_error( - "Destination array can not accommodate all the " - "elements of source array."); - } - } - - int src_typenum = src.get_typenum(); - int dst_typenum = dst.get_typenum(); - - namespace td_ns = dpctl::tensor::type_dispatch; - const auto &array_types = td_ns::usm_ndarray_types(); - int src_typeid = array_types.typenum_to_lookup_id(src_typenum); - int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); - - int dst_itemsize = dst.get_elemsize(); - bool supports_atomics = false; - - switch (dst_itemsize) { - case sizeof(float): - { - void *data_ptr = dst.get_data(); - const auto &ctx = exec_q.get_context(); - auto usm_type = sycl::get_pointer_type(data_ptr, ctx); - supports_atomics = check_atomic_support_size4(exec_q, usm_type); - } break; - case sizeof(double): - { - void *data_ptr = dst.get_data(); - const auto &ctx = exec_q.get_context(); - auto usm_type = sycl::get_pointer_type(data_ptr, ctx); - - supports_atomics = check_atomic_support_size8(exec_q, usm_type); - } break; - } - - // handle special case when both reduction and iteration are 1D contiguous - // and can be done with atomics - if 
(supports_atomics) { - bool is_src_c_contig = src.is_c_contiguous(); - bool is_dst_c_contig = dst.is_c_contiguous(); - bool is_src_f_contig = src.is_f_contiguous(); - - if ((is_src_c_contig && is_dst_c_contig) || - (is_src_f_contig && dst_nelems == 1)) - { - auto fn = axis1_dispatch_table[src_typeid][dst_typeid]; - - if (fn != nullptr) { - size_t iter_nelems = dst_nelems; - - constexpr py::ssize_t zero_offset = 0; - - sycl::event reduction_over_axis_contig_ev = - fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), - dst.get_data(), - zero_offset, // iteration_src_offset - zero_offset, // iteration_dst_offset - zero_offset, // reduction_src_offset - depends); - - sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {reduction_over_axis_contig_ev}); - - return std::make_pair(keep_args_event, - reduction_over_axis_contig_ev); - } - } - else if (is_src_f_contig && - ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) - { - auto fn = axis0_dispatch_table[src_typeid][dst_typeid]; - if (fn != nullptr) { - size_t iter_nelems = dst_nelems; - - constexpr py::ssize_t zero_offset = 0; - - sycl::event reduction_over_axis_contig_ev = - fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), - dst.get_data(), - zero_offset, // iteration_src_offset - zero_offset, // iteration_dst_offset - zero_offset, // reduction_src_offset - depends); - - sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {reduction_over_axis_contig_ev}); - - return std::make_pair(keep_args_event, - reduction_over_axis_contig_ev); - } - } - } - - using dpctl::tensor::py_internal::simplify_iteration_space; - using dpctl::tensor::py_internal::simplify_iteration_space_1; - - auto const &src_shape_vecs = src.get_shape_vector(); - auto const &src_strides_vecs = src.get_strides_vector(); - auto const &dst_strides_vecs = dst.get_strides_vector(); - - int reduction_nd = trailing_dims_to_reduce; - const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; - using shT = std::vector; - shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, - std::end(src_strides_vecs)); - - shT simplified_reduction_shape; - shT simplified_reduction_src_strides; - py::ssize_t reduction_src_offset(0); - - simplify_iteration_space_1( - reduction_nd, reduction_shape_ptr, reduction_src_strides, - // output - simplified_reduction_shape, simplified_reduction_src_strides, - reduction_src_offset); - - const py::ssize_t *iteration_shape_ptr = src_shape_ptr; - - shT iteration_src_strides(std::begin(src_strides_vecs), - std::begin(src_strides_vecs) + iteration_nd); - shT const &iteration_dst_strides = dst_strides_vecs; - - shT simplified_iteration_shape; - shT simplified_iteration_src_strides; - shT simplified_iteration_dst_strides; - py::ssize_t iteration_src_offset(0); - py::ssize_t iteration_dst_offset(0); - - if (iteration_nd == 0) { - if (dst_nelems != 1) { - throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); - } - iteration_nd = 1; - simplified_iteration_shape.push_back(1); - simplified_iteration_src_strides.push_back(0); - simplified_iteration_dst_strides.push_back(0); - } - else { - simplify_iteration_space(iteration_nd, iteration_shape_ptr, - iteration_src_strides, iteration_dst_strides, - // output - simplified_iteration_shape, - simplified_iteration_src_strides, - simplified_iteration_dst_strides, - iteration_src_offset, iteration_dst_offset); - } - - if (supports_atomics && (reduction_nd == 1) && (iteration_nd == 1)) { - bool mat_reduce_over_axis1 = 
false; - bool mat_reduce_over_axis0 = false; - bool array_reduce_all_elems = false; - size_t iter_nelems = dst_nelems; - - if (simplified_reduction_src_strides[0] == 1) { - array_reduce_all_elems = (simplified_iteration_shape[0] == 1); - mat_reduce_over_axis1 = - (simplified_iteration_dst_strides[0] == 1) && - (static_cast(simplified_iteration_src_strides[0]) == - reduction_nelems); - } - else if (static_cast(simplified_reduction_src_strides[0]) == - iter_nelems) - { - mat_reduce_over_axis0 = - (simplified_iteration_dst_strides[0] == 1) && - (simplified_iteration_src_strides[0] == 1); - } - - if (mat_reduce_over_axis1 || array_reduce_all_elems) { - auto fn = axis1_dispatch_table[src_typeid][dst_typeid]; - if (fn != nullptr) { - sycl::event reduction_over_axis1_contig_ev = - fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), - dst.get_data(), iteration_src_offset, - iteration_dst_offset, reduction_src_offset, depends); - - sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {reduction_over_axis1_contig_ev}); - - return std::make_pair(keep_args_event, - reduction_over_axis1_contig_ev); - } - } - else if (mat_reduce_over_axis0) { - auto fn = axis0_dispatch_table[src_typeid][dst_typeid]; - if (fn != nullptr) { - sycl::event reduction_over_axis0_contig_ev = - fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), - dst.get_data(), iteration_src_offset, - iteration_dst_offset, reduction_src_offset, depends); - - sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {reduction_over_axis0_contig_ev}); - - return std::make_pair(keep_args_event, - reduction_over_axis0_contig_ev); - } - } - } - - // remove_all_extents gets underlying type of table - using strided_fn_ptr_T = - typename std::remove_all_extents::type; - strided_fn_ptr_T fn = nullptr; - - if (supports_atomics) { - fn = atomic_dispatch_table[src_typeid][dst_typeid]; - } - - if (fn == nullptr) { - // use slower reduction implementation using temporaries - fn = temps_dispatch_table[src_typeid][dst_typeid]; - if (fn == nullptr) { - throw std::runtime_error("Datatypes are not supported"); - } - } - - std::vector host_task_events{}; - - using dpctl::tensor::offset_utils::device_allocate_and_pack; - - const auto &arrays_metainfo_packing_triple_ = - device_allocate_and_pack( - exec_q, host_task_events, - // iteration metadata - simplified_iteration_shape, simplified_iteration_src_strides, - simplified_iteration_dst_strides, - // reduction metadata - simplified_reduction_shape, simplified_reduction_src_strides); - py::ssize_t *temp_allocation_ptr = - std::get<0>(arrays_metainfo_packing_triple_); - if (temp_allocation_ptr == nullptr) { - throw std::runtime_error("Unable to allocate memory on device"); - } - const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); - - py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; - py::ssize_t *reduction_shape_stride = - temp_allocation_ptr + 3 * simplified_iteration_shape.size(); - - std::vector all_deps; - all_deps.reserve(depends.size() + 1); - all_deps.resize(depends.size()); - std::copy(depends.begin(), depends.end(), all_deps.begin()); - all_deps.push_back(copy_metadata_ev); - - auto reduction_ev = - fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), dst.get_data(), - iteration_nd, iter_shape_and_strides, iteration_src_offset, - iteration_dst_offset, - reduction_nd, // number dimensions being reduced - reduction_shape_stride, reduction_src_offset, all_deps); - - sycl::event temp_cleanup_ev = 
exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(reduction_ev); - const auto &ctx = exec_q.get_context(); - cgh.host_task([ctx, temp_allocation_ptr] { - sycl::free(temp_allocation_ptr, ctx); - }); - }); - host_task_events.push_back(temp_cleanup_ev); - - sycl::event keep_args_event = - dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); - - return std::make_pair(keep_args_event, reduction_ev); -} - -/* ==================== Search reductions ====================== */ - -template -std::pair py_search_over_axis( - const dpctl::tensor::usm_ndarray &src, - int trailing_dims_to_reduce, // comp over this many trailing indexes - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends, - const fn_tableT &dispatch_table) -{ - int src_nd = src.get_ndim(); - int iteration_nd = src_nd - trailing_dims_to_reduce; - if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { - throw py::value_error("Trailing_dim_to_reduce must be positive, but no " - "greater than rank of the array being reduced"); - } - - int dst_nd = dst.get_ndim(); - if (dst_nd != iteration_nd) { - throw py::value_error("Destination array rank does not match input " - "array rank and number of reduced dimensions"); - } - - const py::ssize_t *src_shape_ptr = src.get_shape_raw(); - const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); - - bool same_shapes = true; - for (int i = 0; same_shapes && (i < dst_nd); ++i) { - same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); - } - - if (!same_shapes) { - throw py::value_error("Destination shape does not match unreduced " - "dimensions of the input shape"); - } - - if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { - throw py::value_error( - "Execution queue is not compatible with allocation queues"); - } - - size_t dst_nelems = dst.get_size(); - - size_t reduction_nelems(1); - for (int i = dst_nd; i < src_nd; ++i) { - reduction_nelems *= static_cast(src_shape_ptr[i]); - } - - // check that dst and src do not overlap - auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); - if (overlap(src, dst)) { - throw py::value_error("Arrays index overlapping segments of memory"); - } - - // destination must be ample enough to accommodate all elements - { - auto dst_offsets = dst.get_minmax_offsets(); - size_t range = - static_cast(dst_offsets.second - dst_offsets.first); - if (range + 1 < dst_nelems) { - throw py::value_error( - "Destination array can not accommodate all the " - "elements of source array."); - } - } - - int src_typenum = src.get_typenum(); - int dst_typenum = dst.get_typenum(); - - namespace td_ns = dpctl::tensor::type_dispatch; - const auto &array_types = td_ns::usm_ndarray_types(); - int src_typeid = array_types.typenum_to_lookup_id(src_typenum); - int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); - - using dpctl::tensor::py_internal::simplify_iteration_space; - using dpctl::tensor::py_internal::simplify_iteration_space_1; - - auto const &src_shape_vecs = src.get_shape_vector(); - auto const &src_strides_vecs = src.get_strides_vector(); - auto const &dst_strides_vecs = dst.get_strides_vector(); - - int reduction_nd = trailing_dims_to_reduce; - const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; - using shT = std::vector; - shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, - std::end(src_strides_vecs)); - - shT compact_reduction_shape; - shT compact_reduction_src_strides; - py::ssize_t reduction_src_offset(0); - - compact_iteration_space( - 
reduction_nd, reduction_shape_ptr, reduction_src_strides, - // output - compact_reduction_shape, compact_reduction_src_strides); - - const py::ssize_t *iteration_shape_ptr = src_shape_ptr; - - shT iteration_src_strides(std::begin(src_strides_vecs), - std::begin(src_strides_vecs) + iteration_nd); - shT const &iteration_dst_strides = dst_strides_vecs; - - shT simplified_iteration_shape; - shT simplified_iteration_src_strides; - shT simplified_iteration_dst_strides; - py::ssize_t iteration_src_offset(0); - py::ssize_t iteration_dst_offset(0); - - if (iteration_nd == 0) { - if (dst_nelems != 1) { - throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); - } - iteration_nd = 1; - simplified_iteration_shape.push_back(1); - simplified_iteration_src_strides.push_back(0); - simplified_iteration_dst_strides.push_back(0); - } - else { - simplify_iteration_space(iteration_nd, iteration_shape_ptr, - iteration_src_strides, iteration_dst_strides, - // output - simplified_iteration_shape, - simplified_iteration_src_strides, - simplified_iteration_dst_strides, - iteration_src_offset, iteration_dst_offset); - } - - auto fn = dispatch_table[src_typeid][dst_typeid]; - if (fn == nullptr) { - throw std::runtime_error("Datatypes are not supported"); - } - - std::vector host_task_events{}; - - using dpctl::tensor::offset_utils::device_allocate_and_pack; - - const auto &arrays_metainfo_packing_triple_ = - device_allocate_and_pack( - exec_q, host_task_events, - // iteration metadata - simplified_iteration_shape, simplified_iteration_src_strides, - simplified_iteration_dst_strides, - // reduction metadata - compact_reduction_shape, compact_reduction_src_strides); - py::ssize_t *temp_allocation_ptr = - std::get<0>(arrays_metainfo_packing_triple_); - if (temp_allocation_ptr == nullptr) { - throw std::runtime_error("Unable to allocate memory on device"); - } - const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); - - py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; - py::ssize_t *reduction_shape_stride = - temp_allocation_ptr + 3 * simplified_iteration_shape.size(); - - std::vector all_deps; - all_deps.reserve(depends.size() + 1); - all_deps.resize(depends.size()); - std::copy(depends.begin(), depends.end(), all_deps.begin()); - all_deps.push_back(copy_metadata_ev); - - auto comp_ev = fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), - dst.get_data(), iteration_nd, iter_shape_and_strides, - iteration_src_offset, iteration_dst_offset, - reduction_nd, // number dimensions being reduced - reduction_shape_stride, reduction_src_offset, all_deps); - - sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(comp_ev); - const auto &ctx = exec_q.get_context(); - cgh.host_task([ctx, temp_allocation_ptr] { - sycl::free(temp_allocation_ptr, ctx); - }); - }); - host_task_events.push_back(temp_cleanup_ev); - - sycl::event keep_args_event = - dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); - - return std::make_pair(keep_args_event, comp_ev); -} - -extern void init_reduction_functions(py::module_ m); - -} // namespace py_internal -} // namespace tensor -} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/argmax.cpp b/dpctl/tensor/libtensor/source/reductions/argmax.cpp new file mode 100644 index 0000000000..1d83bf9c2d --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/argmax.cpp @@ -0,0 +1,119 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control 
(dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "utils/type_dispatch.hpp" + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::search_strided_impl_fn_ptr; +static search_strided_impl_fn_ptr + argmax_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmax_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmax_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_argmax_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::search_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::ArgmaxOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(argmax_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::ArgmaxOverAxis1TempsContigFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(argmax_over_axis1_contig_temps_dispatch_table); + + using dpctl::tensor::kernels::ArgmaxOverAxis0TempsContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(argmax_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_argmax(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_argmax_over_axis_dispatch_tables; + populate_argmax_over_axis_dispatch_tables(); + using impl::argmax_over_axis0_contig_temps_dispatch_table; + using impl::argmax_over_axis1_contig_temps_dispatch_table; + using impl::argmax_over_axis_strided_temps_dispatch_table; + + auto argmax_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_search_over_axis; + return py_search_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + argmax_over_axis_strided_temps_dispatch_table, + argmax_over_axis0_contig_temps_dispatch_table, + argmax_over_axis1_contig_temps_dispatch_table); + }; + m.def("_argmax_over_axis", argmax_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // 
namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/argmax.hpp b/dpctl/tensor/libtensor/source/reductions/argmax.hpp new file mode 100644 index 0000000000..9958396b43 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/argmax.hpp @@ -0,0 +1,41 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_argmax(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/argmin.cpp b/dpctl/tensor/libtensor/source/reductions/argmin.cpp new file mode 100644 index 0000000000..c6469e6864 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/argmin.cpp @@ -0,0 +1,119 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "utils/type_dispatch.hpp" + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::search_strided_impl_fn_ptr; +static search_strided_impl_fn_ptr + argmin_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmin_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmin_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_argmin_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::search_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::ArgminOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(argmin_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::ArgminOverAxis1TempsContigFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(argmin_over_axis1_contig_temps_dispatch_table); + + using dpctl::tensor::kernels::ArgminOverAxis0TempsContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(argmin_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_argmin(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_argmin_over_axis_dispatch_tables; + populate_argmin_over_axis_dispatch_tables(); + using impl::argmin_over_axis0_contig_temps_dispatch_table; + using impl::argmin_over_axis1_contig_temps_dispatch_table; + using impl::argmin_over_axis_strided_temps_dispatch_table; + + auto argmin_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_search_over_axis; + return py_search_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + argmin_over_axis_strided_temps_dispatch_table, + argmin_over_axis0_contig_temps_dispatch_table, + argmin_over_axis1_contig_temps_dispatch_table); + }; + m.def("_argmin_over_axis", argmin_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/argmin.hpp b/dpctl/tensor/libtensor/source/reductions/argmin.hpp new file mode 100644 index 0000000000..ea6ef1931c --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/argmin.hpp @@ -0,0 +1,41 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
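The argmin binding above mirrors the argmax one: both route through py_search_over_axis with one strided dispatch table plus axis-0 and axis-1 contiguous tables, and neither search reduction has an atomic variant. As a rough illustration of what these entry points back at the Python level, the sketch below uses the public dpt.argmax/dpt.argmin wrappers; their names and the axis keyword are assumed from the existing dpctl.tensor namespace rather than taken from this patch.

import dpctl.tensor as dpt

# 3x4 C-contiguous input; reducing over the trailing axis should take the
# axis-1 contiguous kernel, while other layouts fall back to the strided one.
x = dpt.reshape(dpt.arange(12, dtype=dpt.float32), (3, 4))

row_argmax = dpt.argmax(x, axis=-1)  # max index per row    -> [3 3 3]
col_argmin = dpt.argmin(x, axis=0)   # min index per column -> [0 0 0 0]

print(dpt.asnumpy(row_argmax), dpt.asnumpy(col_argmin))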
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_argmin(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp b/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp new file mode 100644 index 0000000000..e3b015a4e0 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp @@ -0,0 +1,136 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "utils/type_dispatch.hpp" + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + logsumexp_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + logsumexp_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + logsumexp_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_logsumexp_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + using dpctl::tensor::kernels::LogSumExpOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table( + logsumexp_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::LogSumExpOverAxis1TempsContigFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table( + logsumexp_over_axis1_contig_temps_dispatch_table); + + using dpctl::tensor::kernels::LogSumExpOverAxis0TempsContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table( + logsumexp_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_logsumexp(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_logsumexp_over_axis_dispatch_tables; + populate_logsumexp_over_axis_dispatch_tables(); + using impl::logsumexp_over_axis0_contig_temps_dispatch_table; + using impl::logsumexp_over_axis1_contig_temps_dispatch_table; + using impl::logsumexp_over_axis_strided_temps_dispatch_table; + + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + + auto logsumexp_pyapi = [&](const arrayT &src, + int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_tree_reduction_over_axis; + return py_tree_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + logsumexp_over_axis_strided_temps_dispatch_table, + logsumexp_over_axis0_contig_temps_dispatch_table, + logsumexp_over_axis1_contig_temps_dispatch_table); + }; + m.def("_logsumexp_over_axis", logsumexp_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto logsumexp_dtype_supported = [&](const py::dtype &input_dtype, + const py::dtype &output_dtype) { + using dpctl::tensor::py_internal::py_tree_reduction_dtype_supported; + return py_tree_reduction_dtype_supported( + input_dtype, output_dtype, + logsumexp_over_axis_strided_temps_dispatch_table); + }; + m.def("_logsumexp_over_axis_dtype_supported", logsumexp_dtype_supported, + "", py::arg("arg_dtype"), 
py::arg("out_dtype")); + } +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/logsumexp.hpp b/dpctl/tensor/libtensor/source/reductions/logsumexp.hpp new file mode 100644 index 0000000000..46b2156f46 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/logsumexp.hpp @@ -0,0 +1,41 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_logsumexp(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/max.cpp b/dpctl/tensor/libtensor/source/reductions/max.cpp new file mode 100644 index 0000000000..32c60b943b --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/max.cpp @@ -0,0 +1,171 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/type_dispatch.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + max_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + max_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + max_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + max_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + max_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + max_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_max_over_axis_dispatch_tables(void) +{ + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::MaxOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(max_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::MaxOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(max_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::MaxOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(max_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::MaxOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(max_over_axis0_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::MaxOverAxis1TempsContigFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(max_over_axis1_contig_temps_dispatch_table); + + using dpctl::tensor::kernels::MaxOverAxis0TempsContigFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(max_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t max_atomic_support_vector[td_ns::num_types]; + +void populate_max_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::MaxAtomicSupportFactory; + DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(max_atomic_support_vector); +} + +} // namespace impl + +void init_max(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_max_over_axis_dispatch_tables; + populate_max_over_axis_dispatch_tables(); + using impl::max_over_axis0_contig_atomic_dispatch_table; + using impl::max_over_axis0_contig_temps_dispatch_table; + using impl::max_over_axis1_contig_atomic_dispatch_table; + using impl::max_over_axis1_contig_temps_dispatch_table; + using 
impl::max_over_axis_strided_atomic_dispatch_table; + using impl::max_over_axis_strided_temps_dispatch_table; + + using impl::populate_max_atomic_support_dispatch_vector; + populate_max_atomic_support_dispatch_vector(); + using impl::max_atomic_support_vector; + + auto max_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_reduction_over_axis; + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + max_over_axis_strided_atomic_dispatch_table, + max_over_axis0_contig_atomic_dispatch_table, + max_over_axis1_contig_atomic_dispatch_table, + max_over_axis_strided_temps_dispatch_table, + max_over_axis0_contig_temps_dispatch_table, + max_over_axis1_contig_temps_dispatch_table, + max_atomic_support_vector); + }; + m.def("_max_over_axis", max_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/max.hpp b/dpctl/tensor/libtensor/source/reductions/max.hpp new file mode 100644 index 0000000000..05a31fc1fb --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/max.hpp @@ -0,0 +1,41 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_max(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/min.cpp b/dpctl/tensor/libtensor/source/reductions/min.cpp new file mode 100644 index 0000000000..de1a81387d --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/min.cpp @@ -0,0 +1,173 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
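init_max hands py_reduction_over_axis both the atomic and the temps tables together with max_atomic_support_vector, so the atomic kernel is used only when the output type and allocation pass check_atomic_support; otherwise the temps path runs. From Python the choice is invisible, as the hedged sketch below suggests; dpt.max/dpt.min and the has_aspect_atomic64 property name on dpctl.SyclDevice are assumptions about the existing dpctl API, not part of this patch.

import dpctl.tensor as dpt

x = dpt.reshape(dpt.arange(12, dtype=dpt.int32), (3, 4))
print(dpt.asnumpy(dpt.max(x, axis=1)))   # [ 3  7 11]
print(dpt.asnumpy(dpt.min(x, axis=0)))   # [0 1 2 3]

# Roughly one of the inputs check_atomic_support consults for a
# 64-bit result type (property name assumed from dpctl.SyclDevice):
print(x.sycl_device.has_aspect_atomic64)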
+// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/type_dispatch.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + min_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + min_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + min_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + min_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + min_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + min_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_min_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::MinOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(min_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::MinOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(min_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::MinOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(min_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::MinOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(min_over_axis0_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::MinOverAxis1TempsContigFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(min_over_axis1_contig_temps_dispatch_table); + + using dpctl::tensor::kernels::MinOverAxis0TempsContigFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(min_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t min_atomic_support_vector[td_ns::num_types]; + +void populate_min_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::MinAtomicSupportFactory; + DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(min_atomic_support_vector); +} + +} // namespace impl + +void init_min(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_min_over_axis_dispatch_tables; + populate_min_over_axis_dispatch_tables(); + using 
impl::min_over_axis0_contig_atomic_dispatch_table; + using impl::min_over_axis0_contig_temps_dispatch_table; + using impl::min_over_axis1_contig_atomic_dispatch_table; + using impl::min_over_axis1_contig_temps_dispatch_table; + using impl::min_over_axis_strided_atomic_dispatch_table; + using impl::min_over_axis_strided_temps_dispatch_table; + + using impl::populate_min_atomic_support_dispatch_vector; + populate_min_atomic_support_dispatch_vector(); + using impl::min_atomic_support_vector; + + auto min_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_reduction_over_axis; + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + min_over_axis_strided_atomic_dispatch_table, + min_over_axis0_contig_atomic_dispatch_table, + min_over_axis1_contig_atomic_dispatch_table, + min_over_axis_strided_temps_dispatch_table, + min_over_axis0_contig_temps_dispatch_table, + min_over_axis1_contig_temps_dispatch_table, + min_atomic_support_vector); + }; + m.def("_min_over_axis", min_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/min.hpp b/dpctl/tensor/libtensor/source/reductions/min.hpp new file mode 100644 index 0000000000..cad94c7533 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/min.hpp @@ -0,0 +1,41 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_min(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/prod.cpp b/dpctl/tensor/libtensor/source/reductions/prod.cpp new file mode 100644 index 0000000000..a90d04304a --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/prod.cpp @@ -0,0 +1,187 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
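All of these bindings share the (src, trailing_dims_to_reduce, dst, sycl_queue, depends) signature, so the axes being reduced must already be the trailing ones. The sketch below illustrates that calling convention directly against the private binding; the permutation step is an assumption about what the Python-level wrapper does before calling in, and dst is allocated manually only for illustration.

import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti

x = dpt.ones((2, 3, 4), dtype=dpt.float32)
red_axes, keep_axes = (0, 2), (1,)

# Move reduced axes to the end, then pass their count.
src = dpt.permute_dims(x, keep_axes + red_axes)          # shape (3, 2, 4)
dst = dpt.empty((3,), dtype=dpt.float32, sycl_queue=x.sycl_queue)

ht_ev, red_ev = ti._min_over_axis(
    src=src, trailing_dims_to_reduce=2, dst=dst, sycl_queue=x.sycl_queue
)
ht_ev.wait()
print(dpt.asnumpy(dst))   # [1. 1. 1.]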
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/type_dispatch.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + prod_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + prod_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + prod_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + prod_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_prod_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + using dpctl::tensor::kernels::ProductOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(prod_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(prod_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(prod_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(prod_over_axis0_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxis1TempsContigFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(prod_over_axis1_contig_temps_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxis0TempsContigFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(prod_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t prod_atomic_support_vector[td_ns::num_types]; + +void populate_prod_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::ProductAtomicSupportFactory; + DispatchVectorBuilder + dvb; + 
dvb.populate_dispatch_vector(prod_atomic_support_vector); +} + +} // namespace impl + +void init_prod(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_prod_over_axis_dispatch_tables; + populate_prod_over_axis_dispatch_tables(); + using impl::prod_over_axis0_contig_atomic_dispatch_table; + using impl::prod_over_axis0_contig_temps_dispatch_table; + using impl::prod_over_axis1_contig_atomic_dispatch_table; + using impl::prod_over_axis1_contig_temps_dispatch_table; + using impl::prod_over_axis_strided_atomic_dispatch_table; + using impl::prod_over_axis_strided_temps_dispatch_table; + + using impl::populate_prod_atomic_support_dispatch_vector; + populate_prod_atomic_support_dispatch_vector(); + using impl::prod_atomic_support_vector; + + auto prod_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_reduction_over_axis; + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis0_contig_atomic_dispatch_table, + prod_over_axis1_contig_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + prod_over_axis0_contig_temps_dispatch_table, + prod_over_axis1_contig_temps_dispatch_table, + prod_atomic_support_vector); + }; + m.def("_prod_over_axis", prod_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto prod_dtype_supported = + [&](const py::dtype &input_dtype, const py::dtype &output_dtype, + const std::string &dst_usm_type, sycl::queue &q) { + using dpctl::tensor::py_internal::py_reduction_dtype_supported; + return py_reduction_dtype_supported( + input_dtype, output_dtype, dst_usm_type, q, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + prod_atomic_support_vector); + }; + m.def("_prod_over_axis_dtype_supported", prod_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype"), + py::arg("dst_usm_type"), py::arg("sycl_queue")); + } +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/prod.hpp b/dpctl/tensor/libtensor/source/reductions/prod.hpp new file mode 100644 index 0000000000..026e7d8923 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/prod.hpp @@ -0,0 +1,41 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
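Besides the reduction itself, init_prod exposes _prod_over_axis_dtype_supported so the Python layer can ask whether a given (input, output) dtype pair has a kernel before deciding how to accumulate. A minimal sketch of that query is below; NumPy dtype objects are used because the binding takes py::dtype, and what the wrapper does for unsupported pairs is outside this file.

import numpy as np
import dpctl
import dpctl.tensor._tensor_impl as ti

q = dpctl.SyclQueue()
ok = ti._prod_over_axis_dtype_supported(
    arg_dtype=np.dtype("i4"),
    out_dtype=np.dtype("i8"),
    dst_usm_type="device",
    sycl_queue=q,
)
print(ok)   # True if an i4 -> i8 product kernel (atomic or temps) is registered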
+// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_prod(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp b/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp new file mode 100644 index 0000000000..c7313930b4 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp @@ -0,0 +1,132 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "utils/type_dispatch.hpp" + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + hypot_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + hypot_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + hypot_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_hypot_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + using dpctl::tensor::kernels::HypotOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(hypot_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::HypotOverAxis1TempsContigFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(hypot_over_axis1_contig_temps_dispatch_table); + + using dpctl::tensor::kernels::HypotOverAxis0TempsContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(hypot_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_reduce_hypot(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using 
impl::populate_hypot_over_axis_dispatch_tables; + populate_hypot_over_axis_dispatch_tables(); + using impl::hypot_over_axis0_contig_temps_dispatch_table; + using impl::hypot_over_axis1_contig_temps_dispatch_table; + using impl::hypot_over_axis_strided_temps_dispatch_table; + + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + + auto hypot_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_tree_reduction_over_axis; + return py_tree_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + hypot_over_axis_strided_temps_dispatch_table, + hypot_over_axis0_contig_temps_dispatch_table, + hypot_over_axis1_contig_temps_dispatch_table); + }; + m.def("_hypot_over_axis", hypot_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto hypot_dtype_supported = [&](const py::dtype &input_dtype, + const py::dtype &output_dtype) { + using dpctl::tensor::py_internal::py_tree_reduction_dtype_supported; + return py_tree_reduction_dtype_supported( + input_dtype, output_dtype, + hypot_over_axis_strided_temps_dispatch_table); + }; + m.def("_hypot_over_axis_dtype_supported", hypot_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype")); + } +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/reduce_hypot.hpp b/dpctl/tensor/libtensor/source/reductions/reduce_hypot.hpp new file mode 100644 index 0000000000..92b7fac363 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/reduce_hypot.hpp @@ -0,0 +1,41 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
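Like logsumexp, the hypot reduction registers only temps (tree-reduction) tables plus a dtype-support query. The sketch below shows why a dedicated hypot reduction is worth having: hypot-style accumulation avoids squaring the inputs, so large magnitudes do not overflow the way sqrt(sum(x*x)) does. The public dpt.reduce_hypot name is assumed to wrap _hypot_over_axis.

import dpctl.tensor as dpt

x = dpt.full(16, 1.0e30, dtype=dpt.float32)

# (1e30)**2 overflows float32, so the naive norm is inf, while
# hypot-style accumulation is expected to stay finite (4e30 here).
naive = dpt.sqrt(dpt.sum(x * x))
fused = dpt.reduce_hypot(x)

print(dpt.asnumpy(naive), dpt.asnumpy(fused))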
+// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_reduce_hypot(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp b/dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp new file mode 100644 index 0000000000..695f4b73d0 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp @@ -0,0 +1,143 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include +#include +#include + +#include "utils/type_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ +namespace atomic_support +{ + +typedef bool (*atomic_support_fn_ptr_t)(const sycl::queue &, sycl::usm::alloc); + +/*! @brief Function which returns a constant value for atomic support */ +template +bool fixed_decision(const sycl::queue &, sycl::usm::alloc) +{ + return return_value; +} + +/*! 
@brief Template for querying atomic support for a type on a device */ +template +bool check_atomic_support(const sycl::queue &exec_q, + sycl::usm::alloc usm_alloc_type) +{ + constexpr bool atomic32 = (sizeof(T) == 4); + constexpr bool atomic64 = (sizeof(T) == 8); + using dpctl::tensor::type_utils::is_complex; + if constexpr ((!atomic32 && !atomic64) || is_complex::value) { + return fixed_decision(exec_q, usm_alloc_type); + } + else { + bool supports_atomics = false; + const sycl::device &dev = exec_q.get_device(); + if constexpr (atomic64) { + if (!dev.has(sycl::aspect::atomic64)) { + return false; + } + } + switch (usm_alloc_type) { + case sycl::usm::alloc::shared: + supports_atomics = + dev.has(sycl::aspect::usm_atomic_shared_allocations); + break; + case sycl::usm::alloc::host: + supports_atomics = + dev.has(sycl::aspect::usm_atomic_host_allocations); + break; + case sycl::usm::alloc::device: + supports_atomics = true; + break; + default: + supports_atomics = false; + } + return supports_atomics; + } +} + +template struct ArithmeticAtomicSupportFactory +{ + fnT get() + { + using dpctl::tensor::type_utils::is_complex; + if constexpr (std::is_floating_point_v || + std::is_same_v || is_complex::value) + { + // for real- and complex- floating point types, tree reduction has + // better round-off accumulation properties (round-off error is + // proportional to the log2(reduction_size), while naive elementwise + // summation used by atomic implementation has round-off error + // growing proportional to the reduction_size.), hence reduction + // over floating point types should always use tree_reduction + // algorithm, even though atomic implementation may be applicable + return fixed_decision; + } + else { + return check_atomic_support; + } + } +}; + +template struct MinMaxAtomicSupportFactory +{ + fnT get() + { + return check_atomic_support; + } +}; + +template +struct MaxAtomicSupportFactory : public ArithmeticAtomicSupportFactory +{ +}; + +template +struct MinAtomicSupportFactory : public ArithmeticAtomicSupportFactory +{ +}; + +template +struct SumAtomicSupportFactory : public ArithmeticAtomicSupportFactory +{ +}; + +template +struct ProductAtomicSupportFactory + : public ArithmeticAtomicSupportFactory +{ +}; + +} // namespace atomic_support +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/reduction_common.cpp b/dpctl/tensor/libtensor/source/reductions/reduction_common.cpp new file mode 100644 index 0000000000..99edf663ad --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/reduction_common.cpp @@ -0,0 +1,60 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
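The comment in ArithmeticAtomicSupportFactory claims round-off grows roughly with the reduction size for naive (atomic-style) accumulation but only with its logarithm for tree reduction. The self-contained NumPy sketch below illustrates that claim numerically; it is independent of dpctl and SYCL, the array size and dtype are arbitrary, and float32 is used simply to make the drift visible.

import numpy as np

def pairwise(a):
    # Recursive pairwise (tree) summation carried out in a's dtype.
    if a.size <= 2:
        return a.sum(dtype=a.dtype)
    mid = a.size // 2
    return a.dtype.type(pairwise(a[:mid]) + pairwise(a[mid:]))

rng = np.random.default_rng(0)
x = rng.random(2**18, dtype=np.float32)

exact = x.astype(np.float64).sum()   # double-precision reference

seq = np.float32(0.0)
for v in x:                          # naive left-to-right accumulation
    seq += v

print(abs(seq - exact))              # typically much larger ...
print(abs(pairwise(x) - exact))      # ... than the pairwise error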
+// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include + +#include "argmax.hpp" +#include "argmin.hpp" +#include "logsumexp.hpp" +#include "max.hpp" +#include "min.hpp" +#include "prod.hpp" +#include "reduce_hypot.hpp" +#include "sum.hpp" + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +/*! @brief Add reduction functions to Python module */ +void init_reduction_functions(py::module_ m) +{ + init_argmax(m); + init_argmin(m); + init_logsumexp(m); + init_max(m); + init_min(m); + init_prod(m); + init_reduce_hypot(m); + init_sum(m); +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/reduction_common.hpp b/dpctl/tensor/libtensor/source/reductions/reduction_common.hpp new file mode 100644 index 0000000000..61c992364a --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/reduction_common.hpp @@ -0,0 +1,41 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_reduction_functions(py::module_); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp b/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp new file mode 100644 index 0000000000..da8da0938d --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp @@ -0,0 +1,1095 @@ +//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions, +/// specifically functions for reductions. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "dpctl4pybind11.hpp" +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +/* ====================== dtype supported ======================== */ + +/*! @brief Template implementing Python API for querying type support by + * reduction which may support atomics */ +template +bool py_reduction_dtype_supported( + const py::dtype &input_dtype, + const py::dtype &output_dtype, + const std::string &dst_usm_type, + sycl::queue &q, + const fnT &atomic_dispatch_table, + const fnT &temps_dispatch_table, + const CheckAtomicSupportFnT &check_atomic_support) +{ + int arg_tn = + input_dtype.num(); // NumPy type numbers are the same as in dpctl + int out_tn = + output_dtype.num(); // NumPy type numbers are the same as in dpctl + int arg_typeid = -1; + int out_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + arg_typeid = array_types.typenum_to_lookup_id(arg_tn); + out_typeid = array_types.typenum_to_lookup_id(out_tn); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 || + out_typeid >= td_ns::num_types) + { + throw std::runtime_error("Reduction type support check: lookup failed"); + } + + // remove_all_extents gets underlying type of table + using fn_ptrT = typename std::remove_all_extents::type; + fn_ptrT fn = nullptr; + + sycl::usm::alloc kind = sycl::usm::alloc::unknown; + + if (dst_usm_type == "device") { + kind = sycl::usm::alloc::device; + } + else if (dst_usm_type == "shared") { + kind = sycl::usm::alloc::shared; + } + else if (dst_usm_type == "host") { + kind = sycl::usm::alloc::host; + } + else { + throw py::value_error("Unrecognized `dst_usm_type` argument."); + } + + bool supports_atomics = check_atomic_support[out_typeid](q, kind); + + if (supports_atomics) { + fn = atomic_dispatch_table[arg_typeid][out_typeid]; + } + + if (fn == nullptr) { + // use slower reduction implementation using temporaries + fn = temps_dispatch_table[arg_typeid][out_typeid]; + } + + return (fn != nullptr); +} + +/*! 
@brief Template implementing Python API for querying type support by tree + * reduction */ +template +bool py_tree_reduction_dtype_supported(const py::dtype &input_dtype, + const py::dtype &output_dtype, + const fnT &temps_dispatch_table) +{ + int arg_tn = + input_dtype.num(); // NumPy type numbers are the same as in dpctl + int out_tn = + output_dtype.num(); // NumPy type numbers are the same as in dpctl + int arg_typeid = -1; + int out_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + arg_typeid = array_types.typenum_to_lookup_id(arg_tn); + out_typeid = array_types.typenum_to_lookup_id(out_tn); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 || + out_typeid >= td_ns::num_types) + { + throw std::runtime_error("Reduction type support check: lookup failed"); + } + + auto fn = temps_dispatch_table[arg_typeid][out_typeid]; + + return (fn != nullptr); +} + +/* ==================== Generic reductions ====================== */ + +/*! @brief Template implementing Python API for reduction over axis which may + * support atomics */ +template +std::pair py_reduction_over_axis( + const dpctl::tensor::usm_ndarray &src, + int trailing_dims_to_reduce, // comp over this many trailing indexes + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const strided_fnT &atomic_dispatch_table, + const contig_fnT &axis0_atomic_dispatch_table, + const contig_fnT &axis1_atomic_dispatch_table, + const strided_fnT &temps_dispatch_table, + const contig_fnT &axis0_temps_dispatch_table, + const contig_fnT &axis1_temps_dispatch_table, + const SupportAtomicFnT &check_atomic_support) +{ + int src_nd = src.get_ndim(); + int iteration_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iteration_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + size_t dst_nelems = dst.get_size(); + + size_t reduction_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + reduction_nelems *= static_cast(src_shape_ptr[i]); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + // destination must be ample enough to accommodate all elements + { + auto dst_offsets = dst.get_minmax_offsets(); + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < dst_nelems) { + throw py::value_error( + "Destination array can not accommodate all the " + "elements of source array."); + } + } + + int src_typenum = src.get_typenum(); + 
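
To make the shape bookkeeping in the entry point above concrete, here is a small, self-contained sketch (illustration only, not part of the patch; the shape and trailing_dims_to_reduce values are made up) of how the destination rank and reduction_nelems follow from the number of trailing dimensions being reduced:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        // hypothetical input: a 5-d array reduced over its last three axes
        const std::vector<std::size_t> shape{3, 4, 5, 6, 7};
        const int trailing_dims_to_reduce = 3;

        // leading (iteration) dimensions survive in the destination array
        const std::size_t iteration_nd = shape.size() - trailing_dims_to_reduce;
        const std::vector<std::size_t> dst_shape(shape.begin(),
                                                 shape.begin() + iteration_nd);

        // number of source elements folded into each destination element
        std::size_t reduction_nelems = 1;
        for (std::size_t i = iteration_nd; i < shape.size(); ++i) {
            reduction_nelems *= shape[i];
        }

        // prints: dst rank: 2, reduction_nelems: 210
        std::cout << "dst rank: " << dst_shape.size()
                  << ", reduction_nelems: " << reduction_nelems << "\n";
        return 0;
    }
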
int dst_typenum = dst.get_typenum(); + + namespace td_ns = dpctl::tensor::type_dispatch; + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + + bool supports_atomics = check_atomic_support[dst_typeid](exec_q, usm_type); + + // handle special case when both reduction and iteration are 1D contiguous + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + if ((is_src_c_contig && is_dst_c_contig) || + (is_src_f_contig && dst_nelems == 1)) + { + // remove_all_extents gets underlying type of table + using contig_fn_ptr_T = + typename std::remove_all_extents::type; + contig_fn_ptr_T fn; + if (supports_atomics) { + fn = axis1_atomic_dispatch_table[src_typeid][dst_typeid]; + } + else { + fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + } + if (fn != nullptr) { + size_t iter_nelems = dst_nelems; + + constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + else if (is_src_f_contig && + ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) + { + // remove_all_extents gets underlying type of table + using contig_fn_ptr_T = + typename std::remove_all_extents::type; + contig_fn_ptr_T fn; + if (supports_atomics) { + fn = axis0_atomic_dispatch_table[src_typeid][dst_typeid]; + } + else { + fn = axis0_temps_dispatch_table[src_typeid][dst_typeid]; + } + if (fn != nullptr) { + size_t iter_nelems = dst_nelems; + + constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + + using dpctl::tensor::py_internal::simplify_iteration_space; + using dpctl::tensor::py_internal::simplify_iteration_space_1; + + auto const &src_shape_vecs = src.get_shape_vector(); + auto const &src_strides_vecs = src.get_strides_vector(); + auto const &dst_strides_vecs = dst.get_strides_vector(); + + int reduction_nd = trailing_dims_to_reduce; + const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; + using shT = std::vector; + shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT simplified_reduction_shape; + shT simplified_reduction_src_strides; + py::ssize_t reduction_src_offset(0); + + simplify_iteration_space_1( + reduction_nd, reduction_shape_ptr, reduction_src_strides, + // output + simplified_reduction_shape, simplified_reduction_src_strides, + reduction_src_offset); + + const py::ssize_t *iteration_shape_ptr = src_shape_ptr; + 
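
The USM-kind query feeding the atomic-support check above can be exercised in isolation. A minimal sketch, assuming a SYCL 2020 compiler; supports_float_atomics is a hypothetical stand-in for one entry of the check_atomic_support dispatch vector:

    #include <sycl/sycl.hpp>

    #include <iostream>

    // hypothetical stand-in for a dispatch-vector entry: atomics are usable
    // on device USM, and on shared/host USM only if the device reports the
    // corresponding aspect
    bool supports_float_atomics(const sycl::queue &q, sycl::usm::alloc kind)
    {
        const sycl::device &dev = q.get_device();
        switch (kind) {
        case sycl::usm::alloc::device:
            return true;
        case sycl::usm::alloc::shared:
            return dev.has(sycl::aspect::usm_atomic_shared_allocations);
        case sycl::usm::alloc::host:
            return dev.has(sycl::aspect::usm_atomic_host_allocations);
        default:
            return false;
        }
    }

    int main()
    {
        sycl::queue q;
        float *p = sycl::malloc_device<float>(16, q);

        // same query the reduction entry point performs on dst.get_data()
        const sycl::usm::alloc kind = sycl::get_pointer_type(p, q.get_context());
        std::cout << std::boolalpha << supports_float_atomics(q, kind) << "\n";

        sycl::free(p, q);
        return 0;
    }
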
+ shT iteration_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iteration_nd); + shT const &iteration_dst_strides = dst_strides_vecs; + + shT simplified_iteration_shape; + shT simplified_iteration_src_strides; + shT simplified_iteration_dst_strides; + py::ssize_t iteration_src_offset(0); + py::ssize_t iteration_dst_offset(0); + + if (iteration_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); + } + iteration_nd = 1; + simplified_iteration_shape.push_back(1); + simplified_iteration_src_strides.push_back(0); + simplified_iteration_dst_strides.push_back(0); + } + else { + simplify_iteration_space(iteration_nd, iteration_shape_ptr, + iteration_src_strides, iteration_dst_strides, + // output + simplified_iteration_shape, + simplified_iteration_src_strides, + simplified_iteration_dst_strides, + iteration_src_offset, iteration_dst_offset); + } + + if ((reduction_nd == 1) && (iteration_nd == 1)) { + bool mat_reduce_over_axis1 = false; + bool mat_reduce_over_axis0 = false; + bool array_reduce_all_elems = false; + size_t iter_nelems = dst_nelems; + + if (simplified_reduction_src_strides[0] == 1) { + array_reduce_all_elems = (simplified_iteration_shape[0] == 1); + mat_reduce_over_axis1 = + (simplified_iteration_dst_strides[0] == 1) && + (static_cast(simplified_iteration_src_strides[0]) == + reduction_nelems); + } + else if (static_cast(simplified_reduction_src_strides[0]) == + iter_nelems) + { + mat_reduce_over_axis0 = + (simplified_iteration_dst_strides[0] == 1) && + (simplified_iteration_src_strides[0] == 1); + } + + if (mat_reduce_over_axis1 || array_reduce_all_elems) { + using contig_fn_ptr_T = + typename std::remove_all_extents::type; + contig_fn_ptr_T fn; + if (supports_atomics) { + fn = axis1_atomic_dispatch_table[src_typeid][dst_typeid]; + } + else { + fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + } + if (fn != nullptr) { + sycl::event reduction_over_axis1_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis1_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis1_contig_ev); + } + } + else if (mat_reduce_over_axis0) { + using contig_fn_ptr_T = + typename std::remove_all_extents::type; + contig_fn_ptr_T fn; + if (supports_atomics) { + fn = axis1_atomic_dispatch_table[src_typeid][dst_typeid]; + } + else { + fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + } + if (fn != nullptr) { + sycl::event reduction_over_axis0_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis0_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis0_contig_ev); + } + } + } + + // remove_all_extents gets underlying type of table + using strided_fn_ptr_T = + typename std::remove_all_extents::type; + strided_fn_ptr_T fn = nullptr; + + if (supports_atomics) { + fn = atomic_dispatch_table[src_typeid][dst_typeid]; + } + + if (fn == nullptr) { + // use slower reduction implementation using temporaries + fn = temps_dispatch_table[src_typeid][dst_typeid]; + if (fn == nullptr) { + throw std::runtime_error("Datatypes are not supported"); + } + } 
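
The kernel-selection policy used just above (prefer the atomic kernel when supported, otherwise fall back to the slower tree-reduction kernel that uses temporaries) can be summarized by a standalone sketch; the kernel names and the select_kernel helper are invented for illustration:

    #include <cstdio>
    #include <stdexcept>

    using kernel_fn = void (*)(const char *);

    static void atomic_kernel(const char *msg) { std::printf("atomic: %s\n", msg); }
    static void temps_kernel(const char *msg) { std::printf("temps: %s\n", msg); }

    // mirror of the fallback above: atomic entry first, then temporaries
    static kernel_fn select_kernel(bool supports_atomics,
                                   kernel_fn atomic_entry,
                                   kernel_fn temps_entry)
    {
        kernel_fn fn = supports_atomics ? atomic_entry : nullptr;
        if (fn == nullptr) {
            // use slower reduction implementation using temporaries
            fn = temps_entry;
        }
        if (fn == nullptr) {
            throw std::runtime_error("Datatypes are not supported");
        }
        return fn;
    }

    int main()
    {
        select_kernel(true, atomic_kernel, temps_kernel)("sum over int64");
        select_kernel(false, atomic_kernel, temps_kernel)("sum over float32");
        return 0;
    }
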
+ + std::vector host_task_events{}; + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_iteration_shape, simplified_iteration_src_strides, + simplified_iteration_dst_strides, + // reduction metadata + simplified_reduction_shape, simplified_reduction_src_strides); + py::ssize_t *temp_allocation_ptr = + std::get<0>(arrays_metainfo_packing_triple_); + if (temp_allocation_ptr == nullptr) { + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); + + py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + py::ssize_t *reduction_shape_stride = + temp_allocation_ptr + 3 * simplified_iteration_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto reduction_ev = + fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), dst.get_data(), + iteration_nd, iter_shape_and_strides, iteration_src_offset, + iteration_dst_offset, + reduction_nd, // number dimensions being reduced + reduction_shape_stride, reduction_src_offset, all_deps); + + sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(reduction_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, temp_allocation_ptr] { + sycl::free(temp_allocation_ptr, ctx); + }); + }); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, reduction_ev); +} + +/* ================= No atomic reductions ====================== */ + +/*! 
@brief Template implementing Python API for reduction over axis without + * atomics */ +template +std::pair py_tree_reduction_over_axis( + const dpctl::tensor::usm_ndarray &src, + int trailing_dims_to_reduce, // comp over this many trailing indexes + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const strided_fnT &temps_dispatch_table, + const contig_fnT &axis0_temps_dispatch_table, + const contig_fnT &axis1_temps_dispatch_table) +{ + int src_nd = src.get_ndim(); + int iteration_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iteration_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + size_t dst_nelems = dst.get_size(); + + size_t reduction_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + reduction_nelems *= static_cast(src_shape_ptr[i]); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + // destination must be ample enough to accommodate all elements + { + auto dst_offsets = dst.get_minmax_offsets(); + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < dst_nelems) { + throw py::value_error( + "Destination array can not accommodate all the " + "elements of source array."); + } + } + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + namespace td_ns = dpctl::tensor::type_dispatch; + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + // handle special case when both reduction and iteration are 1D contiguous + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + if ((is_src_c_contig && is_dst_c_contig) || + (is_src_f_contig && dst_nelems == 1)) + { + auto fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + size_t iter_nelems = dst_nelems; + + constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + else if (is_src_f_contig && + ((is_dst_c_contig 
&& dst_nd == 1) || dst.is_f_contiguous())) + { + auto fn = axis0_temps_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + size_t iter_nelems = dst_nelems; + + constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + + using dpctl::tensor::py_internal::simplify_iteration_space; + using dpctl::tensor::py_internal::simplify_iteration_space_1; + + auto const &src_shape_vecs = src.get_shape_vector(); + auto const &src_strides_vecs = src.get_strides_vector(); + auto const &dst_strides_vecs = dst.get_strides_vector(); + + int reduction_nd = trailing_dims_to_reduce; + const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; + using shT = std::vector; + shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT simplified_reduction_shape; + shT simplified_reduction_src_strides; + py::ssize_t reduction_src_offset(0); + + simplify_iteration_space_1( + reduction_nd, reduction_shape_ptr, reduction_src_strides, + // output + simplified_reduction_shape, simplified_reduction_src_strides, + reduction_src_offset); + + const py::ssize_t *iteration_shape_ptr = src_shape_ptr; + + shT iteration_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iteration_nd); + shT const &iteration_dst_strides = dst_strides_vecs; + + shT simplified_iteration_shape; + shT simplified_iteration_src_strides; + shT simplified_iteration_dst_strides; + py::ssize_t iteration_src_offset(0); + py::ssize_t iteration_dst_offset(0); + + if (iteration_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); + } + iteration_nd = 1; + simplified_iteration_shape.push_back(1); + simplified_iteration_src_strides.push_back(0); + simplified_iteration_dst_strides.push_back(0); + } + else { + simplify_iteration_space(iteration_nd, iteration_shape_ptr, + iteration_src_strides, iteration_dst_strides, + // output + simplified_iteration_shape, + simplified_iteration_src_strides, + simplified_iteration_dst_strides, + iteration_src_offset, iteration_dst_offset); + } + + if ((reduction_nd == 1) && (iteration_nd == 1)) { + bool mat_reduce_over_axis1 = false; + bool mat_reduce_over_axis0 = false; + bool array_reduce_all_elems = false; + size_t iter_nelems = dst_nelems; + + if (simplified_reduction_src_strides[0] == 1) { + array_reduce_all_elems = (simplified_iteration_shape[0] == 1); + mat_reduce_over_axis1 = + (simplified_iteration_dst_strides[0] == 1) && + (static_cast(simplified_iteration_src_strides[0]) == + reduction_nelems); + } + else if (static_cast(simplified_reduction_src_strides[0]) == + iter_nelems) + { + mat_reduce_over_axis0 = + (simplified_iteration_dst_strides[0] == 1) && + (simplified_iteration_src_strides[0] == 1); + } + + if (mat_reduce_over_axis1 || array_reduce_all_elems) { + auto fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + sycl::event reduction_over_axis1_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + 
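
The stride tests in this (reduction_nd == 1 && iteration_nd == 1) block amount to a small layout classification that can be checked on its own. A minimal sketch with a hypothetical classify helper and made-up matrix sizes:

    #include <cstddef>
    #include <iostream>

    struct Case1D {
        std::ptrdiff_t reduction_src_stride;
        std::ptrdiff_t iteration_src_stride;
        std::ptrdiff_t iteration_dst_stride;
    };

    // hypothetical mirror of the checks above for a C-contiguous matrix
    static const char *classify(const Case1D &c, std::size_t iter_nelems,
                                std::size_t reduction_nelems)
    {
        if (c.reduction_src_stride == 1 && c.iteration_dst_stride == 1 &&
            static_cast<std::size_t>(c.iteration_src_stride) == reduction_nelems)
        {
            return "contiguous reduction over axis 1";
        }
        if (static_cast<std::size_t>(c.reduction_src_stride) == iter_nelems &&
            c.iteration_src_stride == 1 && c.iteration_dst_stride == 1)
        {
            return "contiguous reduction over axis 0";
        }
        return "generic strided reduction";
    }

    int main()
    {
        // rows of a 25 x 1024 C-contiguous matrix (reduce over axis 1)
        std::cout << classify({1, 1024, 1}, 25, 1024) << "\n";
        // columns of a 1024 x 25 C-contiguous matrix (reduce over axis 0)
        std::cout << classify({25, 1, 1}, 25, 1024) << "\n";
        return 0;
    }
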
sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis1_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis1_contig_ev); + } + } + else if (mat_reduce_over_axis0) { + auto fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + sycl::event reduction_over_axis0_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis0_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis0_contig_ev); + } + } + } + + auto fn = temps_dispatch_table[src_typeid][dst_typeid]; + if (fn == nullptr) { + throw std::runtime_error("Datatypes are not supported"); + } + + std::vector host_task_events{}; + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_iteration_shape, simplified_iteration_src_strides, + simplified_iteration_dst_strides, + // reduction metadata + simplified_reduction_shape, simplified_reduction_src_strides); + py::ssize_t *temp_allocation_ptr = + std::get<0>(arrays_metainfo_packing_triple_); + if (temp_allocation_ptr == nullptr) { + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); + + py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + py::ssize_t *reduction_shape_stride = + temp_allocation_ptr + 3 * simplified_iteration_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto reduction_ev = + fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), dst.get_data(), + iteration_nd, iter_shape_and_strides, iteration_src_offset, + iteration_dst_offset, + reduction_nd, // number dimensions being reduced + reduction_shape_stride, reduction_src_offset, all_deps); + + sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(reduction_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, temp_allocation_ptr] { + sycl::free(temp_allocation_ptr, ctx); + }); + }); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, reduction_ev); +} + +/*! 
@brief Template implementing Python API for searching over an axis */ +template +std::pair py_search_over_axis( + const dpctl::tensor::usm_ndarray &src, + int trailing_dims_to_reduce, // comp over this many trailing indexes + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const strided_fnT &strided_dispatch_table, + const contig_fnT &axis0_contig_dispatch_table, + const contig_fnT &axis1_contig_dispatch_table) +{ + int src_nd = src.get_ndim(); + int iteration_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iteration_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + size_t dst_nelems = dst.get_size(); + + size_t reduction_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + reduction_nelems *= static_cast(src_shape_ptr[i]); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + // destination must be ample enough to accommodate all elements + { + auto dst_offsets = dst.get_minmax_offsets(); + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < dst_nelems) { + throw py::value_error( + "Destination array can not accommodate all the " + "elements of source array."); + } + } + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + namespace td_ns = dpctl::tensor::type_dispatch; + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + // handle special case when both reduction and iteration are 1D contiguous + // and can be done with atomics + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + if ((is_src_c_contig && is_dst_c_contig) || + (is_src_f_contig && dst_nelems == 1)) + { + auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + size_t iter_nelems = dst_nelems; + + constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + else if (is_src_f_contig && + 
((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) + { + auto fn = axis0_contig_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + size_t iter_nelems = dst_nelems; + + constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + + using dpctl::tensor::py_internal::simplify_iteration_space; + using dpctl::tensor::py_internal::simplify_iteration_space_1; + + auto const &src_shape_vecs = src.get_shape_vector(); + auto const &src_strides_vecs = src.get_strides_vector(); + auto const &dst_strides_vecs = dst.get_strides_vector(); + + int reduction_nd = trailing_dims_to_reduce; + const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; + using shT = std::vector; + shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT compact_reduction_shape; + shT compact_reduction_src_strides; + py::ssize_t reduction_src_offset(0); + + compact_iteration_space( + reduction_nd, reduction_shape_ptr, reduction_src_strides, + // output + compact_reduction_shape, compact_reduction_src_strides); + + const py::ssize_t *iteration_shape_ptr = src_shape_ptr; + + shT iteration_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iteration_nd); + shT const &iteration_dst_strides = dst_strides_vecs; + + shT simplified_iteration_shape; + shT simplified_iteration_src_strides; + shT simplified_iteration_dst_strides; + py::ssize_t iteration_src_offset(0); + py::ssize_t iteration_dst_offset(0); + + if (iteration_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); + } + iteration_nd = 1; + simplified_iteration_shape.push_back(1); + simplified_iteration_src_strides.push_back(0); + simplified_iteration_dst_strides.push_back(0); + } + else { + simplify_iteration_space(iteration_nd, iteration_shape_ptr, + iteration_src_strides, iteration_dst_strides, + // output + simplified_iteration_shape, + simplified_iteration_src_strides, + simplified_iteration_dst_strides, + iteration_src_offset, iteration_dst_offset); + } + + if ((reduction_nd == 1) && (iteration_nd == 1)) { + bool mat_reduce_over_axis1 = false; + bool mat_reduce_over_axis0 = false; + bool array_reduce_all_elems = false; + size_t iter_nelems = dst_nelems; + + if (compact_reduction_src_strides[0] == 1) { + array_reduce_all_elems = (simplified_iteration_shape[0] == 1); + mat_reduce_over_axis1 = + (simplified_iteration_dst_strides[0] == 1) && + (static_cast(simplified_iteration_src_strides[0]) == + reduction_nelems); + } + else if (static_cast(compact_reduction_src_strides[0]) == + iter_nelems) { + mat_reduce_over_axis0 = + (simplified_iteration_dst_strides[0] == 1) && + (simplified_iteration_src_strides[0] == 1); + } + + if (mat_reduce_over_axis1 || array_reduce_all_elems) { + auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + sycl::event reduction_over_axis1_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = 
dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis1_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis1_contig_ev); + } + } + else if (mat_reduce_over_axis0) { + auto fn = axis0_contig_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + sycl::event reduction_over_axis0_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis0_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis0_contig_ev); + } + } + } + + auto fn = strided_dispatch_table[src_typeid][dst_typeid]; + if (fn == nullptr) { + throw std::runtime_error("Datatypes are not supported"); + } + + std::vector host_task_events{}; + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + const auto &arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_iteration_shape, simplified_iteration_src_strides, + simplified_iteration_dst_strides, + // reduction metadata + compact_reduction_shape, compact_reduction_src_strides); + py::ssize_t *temp_allocation_ptr = + std::get<0>(arrays_metainfo_packing_triple_); + if (temp_allocation_ptr == nullptr) { + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); + + py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + py::ssize_t *reduction_shape_stride = + temp_allocation_ptr + 3 * simplified_iteration_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto comp_ev = fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_nd, iter_shape_and_strides, + iteration_src_offset, iteration_dst_offset, + reduction_nd, // number dimensions being reduced + reduction_shape_stride, reduction_src_offset, all_deps); + + sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(comp_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, temp_allocation_ptr] { + sycl::free(temp_allocation_ptr, ctx); + }); + }); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, comp_ev); +} + +extern void init_reduction_functions(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/sum.cpp b/dpctl/tensor/libtensor/source/reductions/sum.cpp new file mode 100644 index 0000000000..33803cfd7b --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/sum.cpp @@ -0,0 +1,187 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/type_dispatch.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + sum_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + sum_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + sum_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + sum_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + sum_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + sum_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_sum_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + using dpctl::tensor::kernels::SumOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(sum_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(sum_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(sum_over_axis0_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxis1TempsContigFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(sum_over_axis1_contig_temps_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxis0TempsContigFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(sum_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t sum_atomic_support_vector[td_ns::num_types]; + +void populate_sum_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::SumAtomicSupportFactory; + DispatchVectorBuilder + dvb; + 
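
The DispatchTableBuilder/DispatchVectorBuilder pattern used throughout this translation unit instantiates a factory per supported type and stores the result under that type's lookup id. A toy, self-contained sketch of the same idea (num_types, SupportFactory, and build_vector are invented for illustration):

    #include <array>
    #include <cstdio>

    constexpr int num_types = 3; // pretend lookup ids: 0=int, 1=long long, 2=float

    using support_fn = bool (*)();

    template <typename T> static bool has_atomic_support()
    {
        return sizeof(T) == 4 || sizeof(T) == 8;
    }

    // per-type factory, analogous in shape to SumAtomicSupportFactory above
    template <typename fnT, typename T> struct SupportFactory
    {
        fnT get() { return has_atomic_support<T>; }
    };

    // analogous to populating a dispatch vector over all supported types
    template <template <typename, typename> class Factory>
    static std::array<support_fn, num_types> build_vector()
    {
        return {Factory<support_fn, int>{}.get(),
                Factory<support_fn, long long>{}.get(),
                Factory<support_fn, float>{}.get()};
    }

    int main()
    {
        const auto vec = build_vector<SupportFactory>();
        for (int type_id = 0; type_id < num_types; ++type_id) {
            std::printf("type %d -> %s\n", type_id,
                        vec[type_id]() ? "atomic" : "temps");
        }
        return 0;
    }
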
dvb.populate_dispatch_vector(sum_atomic_support_vector); +} + +} // namespace impl + +void init_sum(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_sum_over_axis_dispatch_tables; + populate_sum_over_axis_dispatch_tables(); + using impl::sum_over_axis0_contig_atomic_dispatch_table; + using impl::sum_over_axis0_contig_temps_dispatch_table; + using impl::sum_over_axis1_contig_atomic_dispatch_table; + using impl::sum_over_axis1_contig_temps_dispatch_table; + using impl::sum_over_axis_strided_atomic_dispatch_table; + using impl::sum_over_axis_strided_temps_dispatch_table; + + using impl::populate_sum_atomic_support_dispatch_vector; + populate_sum_atomic_support_dispatch_vector(); + using impl::sum_atomic_support_vector; + + auto sum_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_reduction_over_axis; + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + sum_over_axis_strided_atomic_dispatch_table, + sum_over_axis0_contig_atomic_dispatch_table, + sum_over_axis1_contig_atomic_dispatch_table, + sum_over_axis_strided_temps_dispatch_table, + sum_over_axis0_contig_temps_dispatch_table, + sum_over_axis1_contig_temps_dispatch_table, + sum_atomic_support_vector); + }; + m.def("_sum_over_axis", sum_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto sum_dtype_supported = + [&](const py::dtype &input_dtype, const py::dtype &output_dtype, + const std::string &dst_usm_type, sycl::queue &q) { + using dpctl::tensor::py_internal::py_reduction_dtype_supported; + return py_reduction_dtype_supported( + input_dtype, output_dtype, dst_usm_type, q, + sum_over_axis_strided_atomic_dispatch_table, + sum_over_axis_strided_temps_dispatch_table, + sum_atomic_support_vector); + }; + m.def("_sum_over_axis_dtype_supported", sum_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype"), + py::arg("dst_usm_type"), py::arg("sycl_queue")); + } +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reductions/sum.hpp b/dpctl/tensor/libtensor/source/reductions/sum.hpp new file mode 100644 index 0000000000..ded0d14809 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reductions/sum.hpp @@ -0,0 +1,41 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_sum(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 254856ec38..d07d5cf084 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -48,7 +48,7 @@ #include "full_ctor.hpp" #include "integer_advanced_indexing.hpp" #include "linear_sequences.hpp" -#include "reduction_over_axis.hpp" +#include "reductions/reduction_common.hpp" #include "repeat.hpp" #include "simplify_iteration_space.hpp" #include "triul_ctor.hpp" diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index f6d1ca086b..a4e202f073 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -173,6 +173,21 @@ def test_largish_reduction(arg_dtype, n): assert dpt.all(dpt.equal(y1, n * m)) +@pytest.mark.parametrize("n", [1023, 1024, 1025]) +def test_largish_reduction_axis1_axis0(n): + get_queue_or_skip() + + m = 25 + x1 = dpt.ones((m, n), dtype="f4") + x2 = dpt.ones((n, m), dtype="f4") + + y1 = dpt.sum(x1, axis=1) + y2 = dpt.sum(x2, axis=0) + + assert dpt.all(y1 == n) + assert dpt.all(y2 == n) + + def test_axis0_bug(): "gh-1391" get_queue_or_skip() diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py index 8d66f35d71..73cf9459a7 100644 --- a/dpctl/tests/test_usm_ndarray_reductions.py +++ b/dpctl/tests/test_usm_ndarray_reductions.py @@ -18,10 +18,32 @@ import numpy as np import pytest +from numpy.testing import assert_allclose import dpctl.tensor as dpt from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported +_no_complex_dtypes = [ + "?", + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", +] + + +_all_dtypes = _no_complex_dtypes + [ + "c8", + "c16", +] + def test_max_min_axis(): get_queue_or_skip() @@ -234,3 +256,176 @@ def test_reduction_arg_validation(): dpt.max(x) with pytest.raises(ValueError): dpt.argmax(x) + + +@pytest.mark.parametrize("arg_dtype", _no_complex_dtypes[1:]) +def test_logsumexp_arg_dtype_default_output_dtype_matrix(arg_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + + m = dpt.ones(100, dtype=arg_dtype) + r = dpt.logsumexp(m) + + assert isinstance(r, dpt.usm_ndarray) + assert r.dtype.kind == "f" + tol = dpt.finfo(r.dtype).resolution + assert_allclose( + dpt.asnumpy(r), + np.logaddexp.reduce(dpt.asnumpy(m), dtype=r.dtype), + rtol=tol, + atol=tol, + ) + + +def test_logsumexp_empty(): + get_queue_or_skip() + x = dpt.empty((0,), dtype="f4") + y = dpt.logsumexp(x) + assert y.shape == tuple() + assert y == -dpt.inf + + +def test_logsumexp_axis(): + get_queue_or_skip() + + m = dpt.ones((3, 4, 5, 6, 7), dtype="f4") + s = dpt.logsumexp(m, axis=(1, 2, -1)) + + assert isinstance(s, dpt.usm_ndarray) + assert s.shape == (3, 6) + tol = dpt.finfo(s.dtype).resolution + assert_allclose( + dpt.asnumpy(s), + np.logaddexp.reduce(dpt.asnumpy(m), axis=(1, 2, -1), dtype=s.dtype), + rtol=tol, + atol=tol, + ) + + +@pytest.mark.parametrize("arg_dtype", 
_no_complex_dtypes[1:]) +@pytest.mark.parametrize("out_dtype", _all_dtypes[1:]) +def test_logsumexp_arg_out_dtype_matrix(arg_dtype, out_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + skip_if_dtype_not_supported(out_dtype, q) + + m = dpt.ones(100, dtype=arg_dtype) + r = dpt.logsumexp(m, dtype=out_dtype) + + assert isinstance(r, dpt.usm_ndarray) + assert r.dtype == dpt.dtype(out_dtype) + + +def test_logsumexp_keepdims(): + get_queue_or_skip() + + m = dpt.ones((3, 4, 5, 6, 7), dtype="i4") + s = dpt.logsumexp(m, axis=(1, 2, -1), keepdims=True) + + assert isinstance(s, dpt.usm_ndarray) + assert s.shape == (3, 1, 1, 6, 1) + + +def test_logsumexp_keepdims_zero_size(): + get_queue_or_skip() + n = 10 + a = dpt.ones((n, 0, n)) + + s1 = dpt.logsumexp(a, keepdims=True) + assert s1.shape == (1, 1, 1) + + s2 = dpt.logsumexp(a, axis=(0, 1), keepdims=True) + assert s2.shape == (1, 1, n) + + s3 = dpt.logsumexp(a, axis=(1, 2), keepdims=True) + assert s3.shape == (n, 1, 1) + + s4 = dpt.logsumexp(a, axis=(0, 2), keepdims=True) + assert s4.shape == (1, 0, 1) + + a0 = a[0] + s5 = dpt.logsumexp(a0, keepdims=True) + assert s5.shape == (1, 1) + + +def test_logsumexp_scalar(): + get_queue_or_skip() + + m = dpt.ones(()) + s = dpt.logsumexp(m) + + assert isinstance(s, dpt.usm_ndarray) + assert m.sycl_queue == s.sycl_queue + assert s.shape == () + + +def test_logsumexp_complex(): + get_queue_or_skip() + + x = dpt.zeros(1, dtype="c8") + with pytest.raises(TypeError): + dpt.logsumexp(x) + + +def test_logsumexp_int_axis(): + get_queue_or_skip() + + x = dpt.zeros((8, 10), dtype="f4") + res = dpt.logsumexp(x, axis=0) + assert res.ndim == 1 + assert res.shape[0] == 10 + + +def test_logsumexp_invalid_arr(): + x = dict() + with pytest.raises(TypeError): + dpt.logsumexp(x) + + +@pytest.mark.parametrize("arg_dtype", _no_complex_dtypes[1:]) +def test_hypot_arg_dtype_default_output_dtype_matrix(arg_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + + m = dpt.ones(100, dtype=arg_dtype) + r = dpt.reduce_hypot(m) + + assert isinstance(r, dpt.usm_ndarray) + assert r.dtype.kind == "f" + tol = dpt.finfo(r.dtype).resolution + assert_allclose( + dpt.asnumpy(r), + np.hypot.reduce(dpt.asnumpy(m), dtype=r.dtype), + rtol=tol, + atol=tol, + ) + + +def test_hypot_empty(): + get_queue_or_skip() + x = dpt.empty((0,), dtype="f4") + y = dpt.reduce_hypot(x) + assert y.shape == tuple() + assert y == 0 + + +@pytest.mark.parametrize("arg_dtype", _no_complex_dtypes[1:]) +@pytest.mark.parametrize("out_dtype", _all_dtypes[1:]) +def test_hypot_arg_out_dtype_matrix(arg_dtype, out_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + skip_if_dtype_not_supported(out_dtype, q) + + m = dpt.ones(100, dtype=arg_dtype) + r = dpt.reduce_hypot(m, dtype=out_dtype) + + assert isinstance(r, dpt.usm_ndarray) + assert r.dtype == dpt.dtype(out_dtype) + + +def test_hypot_complex(): + get_queue_or_skip() + + x = dpt.zeros(1, dtype="c8") + with pytest.raises(TypeError): + dpt.reduce_hypot(x)
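
As a closing illustration of why a dedicated hypot reduction is useful beyond the tests above, the following standalone Python sketch (not a test; the array contents are arbitrary) contrasts it with a naive sqrt-of-sum-of-squares, which overflows in single precision:

    import numpy as np

    import dpctl.tensor as dpt

    x = dpt.full(16, 1e30, dtype="f4")

    # squaring 1e30 overflows float32, so the naive formula returns inf
    naive = dpt.sqrt(dpt.sum(dpt.square(x)))

    # reduce_hypot accumulates without squaring and stays finite (~4e30)
    robust = dpt.reduce_hypot(x)

    print(dpt.asnumpy(naive), dpt.asnumpy(robust))
    print(np.hypot.reduce(np.full(16, 1e30, dtype="f4")))  # NumPy reference
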