Skip to content

Commit

Permalink
Simpler dispatching for in-place broadcast kernels and changes requested by @vtavana
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Jun 14, 2023
1 parent 7b82b2b commit 2e7cbe0
Show file tree
Hide file tree
Showing 12 changed files with 41 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ template <typename fnT, typename T1, typename T2> struct AddTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class add_strided_strided_kernel;
class add_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event add_strided_impl(sycl::queue exec_q,
Expand All @@ -235,8 +235,7 @@ sycl::event add_strided_impl(sycl::queue exec_q,
const std::vector<sycl::event> &additional_depends)
{
return elementwise_common::binary_strided_impl<
argTy1, argTy2, AddOutputType, AddStridedFunctor,
add_strided_strided_kernel>(
argTy1, argTy2, AddOutputType, AddStridedFunctor, add_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends, additional_depends);
}
Expand Down Expand Up @@ -515,14 +514,13 @@ struct AddInplaceRowMatrixBroadcastFactory
fnT get()
{
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (std::is_same_v<resT, void>) {
if constexpr (!std::is_same_v<resT, T2>) {
fnT fn = nullptr;
return fn;
}
else {
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
dpctl::tensor::type_utils::is_complex<T2>::value ||
dpctl::tensor::type_utils::is_complex<resT>::value)
dpctl::tensor::type_utils::is_complex<T2>::value)
{
fnT fn = nullptr;
return fn;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ template <typename fnT, typename T1, typename T2> struct EqualTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class equal_strided_strided_kernel;
class equal_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
Expand All @@ -220,9 +220,9 @@ equal_strided_impl(sycl::queue exec_q,
{
return elementwise_common::binary_strided_impl<
argTy1, argTy2, EqualOutputType, EqualStridedFunctor,
equal_strided_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends, additional_depends);
equal_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
arg1_offset, arg2_p, arg2_offset, res_p,
res_offset, depends, additional_depends);
}

template <typename fnT, typename T1, typename T2> struct EqualStridedFactory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ struct FloorDivideTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class floor_divide_strided_strided_kernel;
class floor_divide_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
Expand All @@ -254,7 +254,7 @@ floor_divide_strided_impl(sycl::queue exec_q,
{
return elementwise_common::binary_strided_impl<
argTy1, argTy2, FloorDivideOutputType, FloorDivideStridedFunctor,
floor_divide_strided_strided_kernel>(
floor_divide_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends, additional_depends);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ template <typename fnT, typename T1, typename T2> struct GreaterTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class greater_strided_strided_kernel;
class greater_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
Expand Down Expand Up @@ -289,7 +289,7 @@ greater_strided_impl(sycl::queue exec_q,
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

cgh.parallel_for<
greater_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
greater_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
{nelems}, GreaterStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
arg1_tp, arg2_tp, res_tp, indexer));
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ struct GreaterEqualTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class greater_equal_strided_strided_kernel;
class greater_equal_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
Expand Down Expand Up @@ -295,8 +295,8 @@ greater_equal_strided_impl(sycl::queue exec_q,
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

cgh.parallel_for<greater_equal_strided_strided_kernel<argTy1, argTy2,
resTy, IndexerT>>(
cgh.parallel_for<
greater_equal_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
{nelems},
GreaterEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
arg1_tp, arg2_tp, res_tp, indexer));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ template <typename fnT, typename T1, typename T2> struct LessTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class less_strided_strided_kernel;
class less_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
Expand Down Expand Up @@ -286,8 +286,7 @@ less_strided_impl(sycl::queue exec_q,
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

cgh.parallel_for<
less_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
cgh.parallel_for<less_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
{nelems}, LessStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
arg1_tp, arg2_tp, res_tp, indexer));
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ template <typename fnT, typename T1, typename T2> struct LessEqualTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class less_equal_strided_strided_kernel;
class less_equal_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
Expand Down Expand Up @@ -290,7 +290,7 @@ less_equal_strided_impl(sycl::queue exec_q,
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

cgh.parallel_for<
less_equal_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
less_equal_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
{nelems}, LessEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
arg1_tp, arg2_tp, res_tp, indexer));
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ template <typename fnT, typename T1, typename T2> struct MultiplyTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class multiply_strided_strided_kernel;
class multiply_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
Expand All @@ -240,9 +240,9 @@ multiply_strided_impl(sycl::queue exec_q,
{
return elementwise_common::binary_strided_impl<
argTy1, argTy2, MultiplyOutputType, MultiplyStridedFunctor,
multiply_strided_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends, additional_depends);
multiply_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
arg1_offset, arg2_p, arg2_offset, res_p,
res_offset, depends, additional_depends);
}

template <typename fnT, typename T1, typename T2> struct MultiplyStridedFactory
Expand Down Expand Up @@ -531,14 +531,13 @@ struct MultiplyInplaceRowMatrixBroadcastFactory
fnT get()
{
using resT = typename MultiplyOutputType<T1, T2>::value_type;
if constexpr (std::is_same_v<resT, void>) {
if constexpr (!std::is_same_v<resT, T2>) {
fnT fn = nullptr;
return fn;
}
else {
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
dpctl::tensor::type_utils::is_complex<T2>::value ||
dpctl::tensor::type_utils::is_complex<resT>::value)
dpctl::tensor::type_utils::is_complex<T2>::value)
{
fnT fn = nullptr;
return fn;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ template <typename fnT, typename T1, typename T2> struct NotEqualTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class not_equal_strided_strided_kernel;
class not_equal_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
Expand All @@ -237,9 +237,9 @@ not_equal_strided_impl(sycl::queue exec_q,
{
return elementwise_common::binary_strided_impl<
argTy1, argTy2, NotEqualOutputType, NotEqualStridedFunctor,
not_equal_strided_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends, additional_depends);
not_equal_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
arg1_offset, arg2_p, arg2_offset, res_p,
res_offset, depends, additional_depends);
}

template <typename fnT, typename T1, typename T2> struct NotEqualStridedFactory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ template <typename fnT, typename T1, typename T2> struct SubtractTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class subtract_strided_strided_kernel;
class subtract_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
Expand All @@ -237,9 +237,9 @@ subtract_strided_impl(sycl::queue exec_q,
{
return elementwise_common::binary_strided_impl<
argTy1, argTy2, SubtractOutputType, SubtractStridedFunctor,
subtract_strided_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends, additional_depends);
subtract_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
arg1_offset, arg2_p, arg2_offset, res_p,
res_offset, depends, additional_depends);
}

template <typename fnT, typename T1, typename T2> struct SubtractStridedFactory
Expand Down Expand Up @@ -544,14 +544,13 @@ struct SubtractInplaceRowMatrixBroadcastFactory
fnT get()
{
using resT = typename SubtractOutputType<T1, T2>::value_type;
if constexpr (std::is_same_v<resT, void>) {
if constexpr (!std::is_same_v<resT, T2>) {
fnT fn = nullptr;
return fn;
}
else {
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
dpctl::tensor::type_utils::is_complex<T2>::value ||
dpctl::tensor::type_utils::is_complex<resT>::value)
dpctl::tensor::type_utils::is_complex<T2>::value)
{
fnT fn = nullptr;
return fn;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ struct TrueDivideTypeMapFactory
};

template <typename T1, typename T2, typename resT, typename IndexerT>
class true_divide_strided_strided_kernel;
class true_divide_strided_kernel;

template <typename argTy1, typename argTy2>
sycl::event
Expand All @@ -220,7 +220,7 @@ true_divide_strided_impl(sycl::queue exec_q,
{
return elementwise_common::binary_strided_impl<
argTy1, argTy2, TrueDivideOutputType, TrueDivideStridedFunctor,
true_divide_strided_strided_kernel>(
true_divide_strided_kernel>(
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
arg2_offset, res_p, res_offset, depends, additional_depends);
}
Expand Down
8 changes: 4 additions & 4 deletions dpctl/tensor/libtensor/source/elementwise_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(

if (strided_fn == nullptr) {
throw std::runtime_error(
"Contiguous implementation is missing for src1_typeid=" +
"Strided implementation is missing for src1_typeid=" +
std::to_string(src1_typeid) +
" and src2_typeid=" + std::to_string(src2_typeid));
}
Expand Down Expand Up @@ -627,7 +627,7 @@ py_binary_inplace_ufunc(dpctl::tensor::usm_ndarray lhs,

if (output_typeid != lhs_typeid) {
throw py::value_error(
"Destination array has unexpected elemental data type.");
"Left-hand side array has unexpected elemental data type.");
}

// check that queues are compatible
Expand Down Expand Up @@ -696,7 +696,7 @@ py_binary_inplace_ufunc(dpctl::tensor::usm_ndarray lhs,

// dispatch for contiguous inputs
if (both_c_contig || both_f_contig) {
auto contig_fn = contig_dispatch_table[lhs_typeid][rhs_typeid];
auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid];

if (contig_fn != nullptr) {
auto comp_ev = contig_fn(exec_q, rhs_nelems, rhs_data, 0, lhs_data,
Expand Down Expand Up @@ -781,7 +781,7 @@ py_binary_inplace_ufunc(dpctl::tensor::usm_ndarray lhs,

if (strided_fn == nullptr) {
throw std::runtime_error(
"Contiguous implementation is missing for rhs_typeid=" +
"Strided implementation is missing for rhs_typeid=" +
std::to_string(rhs_typeid) +
" and lhs_typeid=" + std::to_string(lhs_typeid));
}
Expand Down

0 comments on commit 2e7cbe0

Please sign in to comment.