Skip to content

Commit

Permalink
Merge pull request #1179 from IntelPython/refactoring/dpctl-tensor-ty…
Browse files Browse the repository at this point in the history
…pe-dispatch-namespace

Refactoring/dpctl tensor type dispatch namespace
  • Loading branch information
oleksandr-pavlyk authored Apr 21, 2023
2 parents 4e41318 + 0ab2223 commit e7fc039
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -763,10 +763,6 @@ sycl::event masked_place_some_slices_strided_impl(
return comp_ev;
}

static masked_place_all_slices_strided_impl_fn_ptr_t
masked_place_all_slices_strided_impl_dispatch_vector
[dpctl::tensor::detail::num_types];

template <typename fnT, typename T> struct MaskPlaceAllSlicesStridedFactory
{
fnT get()
Expand All @@ -776,10 +772,6 @@ template <typename fnT, typename T> struct MaskPlaceAllSlicesStridedFactory
}
};

static masked_place_some_slices_strided_impl_fn_ptr_t
masked_place_some_slices_strided_impl_dispatch_vector
[dpctl::tensor::detail::num_types];

template <typename fnT, typename T> struct MaskPlaceSomeSlicesStridedFactory
{
fnT get()
Expand Down
6 changes: 3 additions & 3 deletions dpctl/tensor/libtensor/include/utils/type_dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace dpctl
namespace tensor
{

namespace detail
namespace type_dispatch
{

enum class typenum_t : int
Expand Down Expand Up @@ -164,7 +164,7 @@ struct usm_ndarray_types

int typenum_to_lookup_id(int typenum) const
{
using typenum_t = dpctl::tensor::detail::typenum_t;
using typenum_t = ::dpctl::tensor::type_dispatch::typenum_t;
auto const &api = ::dpctl::detail::dpctl_capi::get();

if (typenum == api.UAR_DOUBLE_) {
Expand Down Expand Up @@ -250,7 +250,7 @@ struct usm_ndarray_types
}
};

} // namespace detail
} // namespace type_dispatch

} // namespace tensor
} // namespace dpctl
68 changes: 30 additions & 38 deletions dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,27 @@ void _split_iteration_space(const shT &shape_vec,

// Computation of positions of masked elements

namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::kernels::indexing::mask_positions_contig_impl_fn_ptr_t;
static mask_positions_contig_impl_fn_ptr_t
mask_positions_contig_dispatch_vector[dpctl::tensor::detail::num_types];
mask_positions_contig_dispatch_vector[td_ns::num_types];

using dpctl::tensor::kernels::indexing::mask_positions_strided_impl_fn_ptr_t;
static mask_positions_strided_impl_fn_ptr_t
mask_positions_strided_dispatch_vector[dpctl::tensor::detail::num_types];
mask_positions_strided_dispatch_vector[td_ns::num_types];

void populate_mask_positions_dispatch_vectors(void)
{
using dpctl::tensor::kernels::indexing::MaskPositionsContigFactory;
dpctl::tensor::detail::DispatchVectorBuilder<
mask_positions_contig_impl_fn_ptr_t, MaskPositionsContigFactory,
dpctl::tensor::detail::num_types>
td_ns::DispatchVectorBuilder<mask_positions_contig_impl_fn_ptr_t,
MaskPositionsContigFactory, td_ns::num_types>
dvb1;
dvb1.populate_dispatch_vector(mask_positions_contig_dispatch_vector);

using dpctl::tensor::kernels::indexing::MaskPositionsStridedFactory;
dpctl::tensor::detail::DispatchVectorBuilder<
mask_positions_strided_impl_fn_ptr_t, MaskPositionsStridedFactory,
dpctl::tensor::detail::num_types>
td_ns::DispatchVectorBuilder<mask_positions_strided_impl_fn_ptr_t,
MaskPositionsStridedFactory, td_ns::num_types>
dvb2;
dvb2.populate_dispatch_vector(mask_positions_strided_dispatch_vector);

Expand Down Expand Up @@ -158,14 +158,13 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
const char *mask_data = mask.get_data();
char *cumsum_data = cumsum.get_data();

auto const &array_types = dpctl::tensor::detail::usm_ndarray_types();
auto const &array_types = td_ns::usm_ndarray_types();

int mask_typeid = array_types.typenum_to_lookup_id(mask_typenum);
int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum);

// cumsum must be int64_t only
constexpr int int64_typeid =
static_cast<int>(dpctl::tensor::detail::typenum_t::INT64);
constexpr int int64_typeid = static_cast<int>(td_ns::typenum_t::INT64);
if (cumsum_typeid != int64_typeid) {
throw py::value_error(
"Cumulative sum array must have int64 data-type.");
Expand Down Expand Up @@ -244,30 +243,28 @@ using dpctl::tensor::kernels::indexing::
masked_extract_all_slices_strided_impl_fn_ptr_t;

static masked_extract_all_slices_strided_impl_fn_ptr_t
masked_extract_all_slices_strided_impl_dispatch_vector
[dpctl::tensor::detail::num_types];
masked_extract_all_slices_strided_impl_dispatch_vector[td_ns::num_types];

using dpctl::tensor::kernels::indexing::
masked_extract_some_slices_strided_impl_fn_ptr_t;

static masked_extract_some_slices_strided_impl_fn_ptr_t
masked_extract_some_slices_strided_impl_dispatch_vector
[dpctl::tensor::detail::num_types];
masked_extract_some_slices_strided_impl_dispatch_vector[td_ns::num_types];

void populate_masked_extract_dispatch_vectors(void)
{
using dpctl::tensor::kernels::indexing::MaskExtractAllSlicesStridedFactory;
dpctl::tensor::detail::DispatchVectorBuilder<
td_ns::DispatchVectorBuilder<
masked_extract_all_slices_strided_impl_fn_ptr_t,
MaskExtractAllSlicesStridedFactory, dpctl::tensor::detail::num_types>
MaskExtractAllSlicesStridedFactory, td_ns::num_types>
dvb1;
dvb1.populate_dispatch_vector(
masked_extract_all_slices_strided_impl_dispatch_vector);

using dpctl::tensor::kernels::indexing::MaskExtractSomeSlicesStridedFactory;
dpctl::tensor::detail::DispatchVectorBuilder<
td_ns::DispatchVectorBuilder<
masked_extract_some_slices_strided_impl_fn_ptr_t,
MaskExtractSomeSlicesStridedFactory, dpctl::tensor::detail::num_types>
MaskExtractSomeSlicesStridedFactory, td_ns::num_types>
dvb2;
dvb2.populate_dispatch_vector(
masked_extract_some_slices_strided_impl_dispatch_vector);
Expand Down Expand Up @@ -359,13 +356,12 @@ py_extract(dpctl::tensor::usm_ndarray src,
int dst_typenum = dst.get_typenum();
int cumsum_typenum = cumsum.get_typenum();

auto const &array_types = dpctl::tensor::detail::usm_ndarray_types();
auto const &array_types = td_ns::usm_ndarray_types();
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum);

constexpr int int64_typeid =
static_cast<int>(dpctl::tensor::detail::typenum_t::INT64);
constexpr int int64_typeid = static_cast<int>(td_ns::typenum_t::INT64);
if (cumsum_typeid != int64_typeid) {
throw py::value_error(
"Unexact data type of cumsum array, expecting 'int64'");
Expand Down Expand Up @@ -557,30 +553,28 @@ using dpctl::tensor::kernels::indexing::
masked_place_all_slices_strided_impl_fn_ptr_t;

static masked_place_all_slices_strided_impl_fn_ptr_t
masked_place_all_slices_strided_impl_dispatch_vector
[dpctl::tensor::detail::num_types];
masked_place_all_slices_strided_impl_dispatch_vector[td_ns::num_types];

using dpctl::tensor::kernels::indexing::
masked_place_some_slices_strided_impl_fn_ptr_t;

static masked_place_some_slices_strided_impl_fn_ptr_t
masked_place_some_slices_strided_impl_dispatch_vector
[dpctl::tensor::detail::num_types];
masked_place_some_slices_strided_impl_dispatch_vector[td_ns::num_types];

void populate_masked_place_dispatch_vectors(void)
{
using dpctl::tensor::kernels::indexing::MaskPlaceAllSlicesStridedFactory;
dpctl::tensor::detail::DispatchVectorBuilder<
masked_place_all_slices_strided_impl_fn_ptr_t,
MaskPlaceAllSlicesStridedFactory, dpctl::tensor::detail::num_types>
td_ns::DispatchVectorBuilder<masked_place_all_slices_strided_impl_fn_ptr_t,
MaskPlaceAllSlicesStridedFactory,
td_ns::num_types>
dvb1;
dvb1.populate_dispatch_vector(
masked_place_all_slices_strided_impl_dispatch_vector);

using dpctl::tensor::kernels::indexing::MaskPlaceSomeSlicesStridedFactory;
dpctl::tensor::detail::DispatchVectorBuilder<
masked_place_some_slices_strided_impl_fn_ptr_t,
MaskPlaceSomeSlicesStridedFactory, dpctl::tensor::detail::num_types>
td_ns::DispatchVectorBuilder<masked_place_some_slices_strided_impl_fn_ptr_t,
MaskPlaceSomeSlicesStridedFactory,
td_ns::num_types>
dvb2;
dvb2.populate_dispatch_vector(
masked_place_some_slices_strided_impl_dispatch_vector);
Expand Down Expand Up @@ -673,13 +667,12 @@ py_place(dpctl::tensor::usm_ndarray dst,
int rhs_typenum = rhs.get_typenum();
int cumsum_typenum = cumsum.get_typenum();

auto const &array_types = dpctl::tensor::detail::usm_ndarray_types();
auto const &array_types = td_ns::usm_ndarray_types();
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum);
int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum);

constexpr int int64_typeid =
static_cast<int>(dpctl::tensor::detail::typenum_t::INT64);
constexpr int int64_typeid = static_cast<int>(td_ns::typenum_t::INT64);
if (cumsum_typeid != int64_typeid) {
throw py::value_error(
"Unexact data type of cumsum array, expecting 'int64'");
Expand Down Expand Up @@ -913,15 +906,14 @@ std::pair<sycl::event, sycl::event> py_nonzero(
py::ssize_t nz_elems = indexes_shape[1];

int indexes_typenum = indexes.get_typenum();
auto const &array_types = dpctl::tensor::detail::usm_ndarray_types();
auto const &array_types = td_ns::usm_ndarray_types();
int indexes_typeid = array_types.typenum_to_lookup_id(indexes_typenum);

int cumsum_typenum = cumsum.get_typenum();
int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum);

// cumsum must be int64_t only
constexpr int int64_typeid =
static_cast<int>(dpctl::tensor::detail::typenum_t::INT64);
constexpr int int64_typeid = static_cast<int>(td_ns::typenum_t::INT64);
if (cumsum_typeid != int64_typeid || indexes_typeid != int64_typeid) {
throw py::value_error(
"Cumulative sum array and index array must have int64 data-type");
Expand Down
12 changes: 6 additions & 6 deletions dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@ namespace tensor
namespace py_internal
{

namespace _ns = dpctl::tensor::detail;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_1d_fn_ptr_t;
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_contig_fn_ptr_t;
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_generic_fn_ptr_t;

static copy_and_cast_generic_fn_ptr_t
copy_and_cast_generic_dispatch_table[_ns::num_types][_ns::num_types];
copy_and_cast_generic_dispatch_table[td_ns::num_types][td_ns::num_types];
static copy_and_cast_1d_fn_ptr_t
copy_and_cast_1d_dispatch_table[_ns::num_types][_ns::num_types];
copy_and_cast_1d_dispatch_table[td_ns::num_types][td_ns::num_types];
static copy_and_cast_contig_fn_ptr_t
copy_and_cast_contig_dispatch_table[_ns::num_types][_ns::num_types];
copy_and_cast_contig_dispatch_table[td_ns::num_types][td_ns::num_types];

namespace py = pybind11;

Expand Down Expand Up @@ -121,7 +121,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
int src_typenum = src.get_typenum();
int dst_typenum = dst.get_typenum();

auto array_types = dpctl::tensor::detail::usm_ndarray_types();
auto array_types = td_ns::usm_ndarray_types();
int src_type_id = array_types.typenum_to_lookup_id(src_typenum);
int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum);

Expand Down Expand Up @@ -277,7 +277,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,

void init_copy_and_cast_usm_to_usm_dispatch_tables(void)
{
using namespace dpctl::tensor::detail;
using namespace td_ns;

using dpctl::tensor::kernels::copy_and_cast::CopyAndCastContigFactory;
DispatchTableBuilder<copy_and_cast_contig_fn_ptr_t,
Expand Down
8 changes: 4 additions & 4 deletions dpctl/tensor/libtensor/source/copy_for_reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ namespace tensor
namespace py_internal
{

namespace _ns = dpctl::tensor::detail;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::kernels::copy_and_cast::copy_for_reshape_fn_ptr_t;
using dpctl::utils::keep_args_alive;

// define static vector
static copy_for_reshape_fn_ptr_t
copy_for_reshape_generic_dispatch_vector[_ns::num_types];
copy_for_reshape_generic_dispatch_vector[td_ns::num_types];

/*
* Copies src into dst (same data type) of different shapes by using flat
Expand Down Expand Up @@ -121,7 +121,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
int src_nd = src.get_ndim();
int dst_nd = dst.get_ndim();

auto array_types = dpctl::tensor::detail::usm_ndarray_types();
auto array_types = td_ns::usm_ndarray_types();
int type_id = array_types.typenum_to_lookup_id(src_typenum);

auto fn = copy_for_reshape_generic_dispatch_vector[type_id];
Expand Down Expand Up @@ -172,7 +172,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,

void init_copy_for_reshape_dispatch_vectors(void)
{
using namespace dpctl::tensor::detail;
using namespace td_ns;
using dpctl::tensor::kernels::copy_and_cast::CopyForReshapeGenericFactory;

DispatchVectorBuilder<copy_for_reshape_fn_ptr_t,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#include "simplify_iteration_space.hpp"

namespace py = pybind11;
namespace _ns = dpctl::tensor::detail;
namespace td_ns = dpctl::tensor::type_dispatch;

namespace dpctl
{
Expand All @@ -49,8 +49,8 @@ using dpctl::tensor::kernels::copy_and_cast::
copy_and_cast_from_host_blocking_fn_ptr_t;

static copy_and_cast_from_host_blocking_fn_ptr_t
copy_and_cast_from_host_blocking_dispatch_table[_ns::num_types]
[_ns::num_types];
copy_and_cast_from_host_blocking_dispatch_table[td_ns::num_types]
[td_ns::num_types];

void copy_numpy_ndarray_into_usm_ndarray(
py::array npy_src,
Expand Down Expand Up @@ -111,7 +111,7 @@ void copy_numpy_ndarray_into_usm_ndarray(
py::detail::array_descriptor_proxy(npy_src.dtype().ptr())->type_num;
int dst_typenum = dst.get_typenum();

auto array_types = dpctl::tensor::detail::usm_ndarray_types();
auto array_types = td_ns::usm_ndarray_types();
int src_type_id = array_types.typenum_to_lookup_id(src_typenum);
int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum);

Expand Down Expand Up @@ -239,11 +239,11 @@ void copy_numpy_ndarray_into_usm_ndarray(

void init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables(void)
{
using namespace dpctl::tensor::detail;
using namespace td_ns;
using dpctl::tensor::kernels::copy_and_cast::CopyAndCastFromHostFactory;

DispatchTableBuilder<copy_and_cast_from_host_blocking_fn_ptr_t,
CopyAndCastFromHostFactory, _ns::num_types>
CopyAndCastFromHostFactory, num_types>
dtb_copy_from_numpy;

dtb_copy_from_numpy.populate_dispatch_table(
Expand Down
8 changes: 4 additions & 4 deletions dpctl/tensor/libtensor/source/eye_ctor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
#include "utils/type_dispatch.hpp"

namespace py = pybind11;
namespace _ns = dpctl::tensor::detail;
namespace td_ns = dpctl::tensor::type_dispatch;

namespace dpctl
{
Expand All @@ -46,7 +46,7 @@ namespace py_internal
using dpctl::utils::keep_args_alive;

using dpctl::tensor::kernels::constructors::eye_fn_ptr_t;
static eye_fn_ptr_t eye_dispatch_vector[_ns::num_types];
static eye_fn_ptr_t eye_dispatch_vector[td_ns::num_types];

std::pair<sycl::event, sycl::event>
usm_ndarray_eye(py::ssize_t k,
Expand All @@ -66,7 +66,7 @@ usm_ndarray_eye(py::ssize_t k,
"allocation queue");
}

auto array_types = dpctl::tensor::detail::usm_ndarray_types();
auto array_types = td_ns::usm_ndarray_types();
int dst_typenum = dst.get_typenum();
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

Expand Down Expand Up @@ -118,7 +118,7 @@ usm_ndarray_eye(py::ssize_t k,

void init_eye_ctor_dispatch_vectors(void)
{
using namespace dpctl::tensor::detail;
using namespace td_ns;
using dpctl::tensor::kernels::constructors::EyeFactory;

DispatchVectorBuilder<eye_fn_ptr_t, EyeFactory, num_types> dvb;
Expand Down
Loading

0 comments on commit e7fc039

Please sign in to comment.