Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update test_linalg.py to run on Iris Xe #1474

Merged
Merged
39 changes: 35 additions & 4 deletions dpnp/backend/kernels/dpnp_krnl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,13 @@ DPCTLSyclEventRef (*dpnp_matmul_ext_c)(DPCTLSyclQueueRef,
const DPCTLEventVectorRef) =
dpnp_matmul_c<_DataType>;

template <typename has_fp64 = std::true_type>
static constexpr DPNPFuncType get_res_type_with_aspect()
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
{
return has_fp64::value ? DPNPFuncType::DPNP_FT_DOUBLE
: DPNPFuncType::DPNP_FT_FLOAT;
}
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved

void func_map_init_linalg(func_map_t &fmap)
{
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_BLN] = {
Expand Down Expand Up @@ -1152,9 +1159,21 @@ void func_map_init_linalg(func_map_t &fmap)
eft_DBL, (void *)dpnp_eig_default_c<double, double>};

fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_eig_ext_c<int32_t, double>};
get_res_type_with_aspect<>(),
(void *)dpnp_eig_ext_c<
int32_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
get_res_type_with_aspect<std::false_type>(),
(void *)dpnp_eig_ext_c<
int32_t, func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_eig_ext_c<int64_t, double>};
get_res_type_with_aspect<>(),
(void *)dpnp_eig_ext_c<
int64_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
get_res_type_with_aspect<std::false_type>(),
(void *)dpnp_eig_ext_c<
int64_t, func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_eig_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_DBL][eft_DBL] = {
Expand All @@ -1170,9 +1189,21 @@ void func_map_init_linalg(func_map_t &fmap)
eft_DBL, (void *)dpnp_eigvals_default_c<double, double>};

fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_eigvals_ext_c<int32_t, double>};
get_res_type_with_aspect<>(),
(void *)dpnp_eigvals_ext_c<
int32_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
get_res_type_with_aspect<std::false_type>(),
(void *)dpnp_eigvals_ext_c<
int32_t, func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_eigvals_ext_c<int64_t, double>};
get_res_type_with_aspect<>(),
(void *)dpnp_eigvals_ext_c<
int64_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
get_res_type_with_aspect<std::false_type>(),
(void *)dpnp_eigvals_ext_c<
int64_t, func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_eigvals_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_DBL][eft_DBL] = {
Expand Down
71 changes: 64 additions & 7 deletions dpnp/backend/kernels/dpnp_krnl_linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,13 @@ DPCTLSyclEventRef (*dpnp_svd_ext_c)(DPCTLSyclQueueRef,
const DPCTLEventVectorRef) =
dpnp_svd_c<_InputDT, _ComputeDT, _SVDT>;

template <typename has_fp64 = std::true_type>
static constexpr DPNPFuncType get_res_type_with_aspect()
{
return has_fp64::value ? DPNPFuncType::DPNP_FT_DOUBLE
: DPNPFuncType::DPNP_FT_FLOAT;
}

void func_map_init_linalg_func(func_map_t &fmap)
{
fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_FLT][eft_FLT] = {
Expand Down Expand Up @@ -879,11 +886,29 @@ void func_map_init_linalg_func(func_map_t &fmap)
eft_DBL, (void *)dpnp_inv_default_c<double, double>};

fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_inv_ext_c<int32_t, double>};
get_res_type_with_aspect<>(),
(void *)dpnp_inv_ext_c<
int32_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
get_res_type_with_aspect<std::false_type>(),
(void *)dpnp_inv_ext_c<
int32_t, func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_inv_ext_c<int64_t, double>};
get_res_type_with_aspect<>(),
(void *)dpnp_inv_ext_c<
int64_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
get_res_type_with_aspect<std::false_type>(),
(void *)dpnp_inv_ext_c<
int64_t, func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_FLT][eft_FLT] = {
eft_DBL, (void *)dpnp_inv_ext_c<float, double>};
get_res_type_with_aspect<>(),
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
(void *)dpnp_inv_ext_c<
float, func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
get_res_type_with_aspect<std::false_type>(),
(void *)dpnp_inv_ext_c<
float, func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_inv_ext_c<double, double>};

Expand Down Expand Up @@ -1039,9 +1064,21 @@ void func_map_init_linalg_func(func_map_t &fmap)
// eft_C128, (void*)dpnp_qr_c<std::complex<double>, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_qr_ext_c<int32_t, double>};
get_res_type_with_aspect<>(),
(void *)dpnp_qr_ext_c<
int32_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
get_res_type_with_aspect<std::false_type>(),
(void *)dpnp_qr_ext_c<
int32_t, func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_qr_ext_c<int64_t, double>};
get_res_type_with_aspect<>(),
(void *)dpnp_qr_ext_c<
int64_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
get_res_type_with_aspect<std::false_type>(),
(void *)dpnp_qr_ext_c<
int64_t, func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_qr_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_DBL][eft_DBL] = {
Expand All @@ -1062,9 +1099,29 @@ void func_map_init_linalg_func(func_map_t &fmap)
std::complex<double>, double>};

fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_svd_ext_c<int32_t, double, double>};
get_res_type_with_aspect<>(),
(void *)dpnp_svd_ext_c<
int32_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>,
func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
get_res_type_with_aspect<std::false_type>(),
(void *)
dpnp_svd_ext_c<int32_t,
func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>,
func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_svd_ext_c<int64_t, double, double>};
get_res_type_with_aspect<>(),
(void *)dpnp_svd_ext_c<
int64_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>,
func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
get_res_type_with_aspect<std::false_type>(),
(void *)
dpnp_svd_ext_c<int64_t,
func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>,
func_type_map_t::find_type<
get_res_type_with_aspect<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_svd_ext_c<float, float, float>};
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_DBL][eft_DBL] = {
Expand Down
2 changes: 2 additions & 0 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ cdef extern from "dpnp_iface_fptr.hpp":
struct DPNPFuncData:
DPNPFuncType return_type
void * ptr
DPNPFuncType return_type_no_fp64
void *ptr_no_fp64

DPNPFuncData get_dpnp_function_ptr(DPNPFuncName name, DPNPFuncType first_type, DPNPFuncType second_type) except +

Expand Down
Loading