Skip to content

Commit

Permalink
Update test_linalg.py to run on Iris Xe (#1474)
Browse files Browse the repository at this point in the history
* Update cholesky function
* Update dpnp.linalg.eig function
* Update dpnp.linalg.eigvals
* Update dpnp.linalg.inv()
* Update dpnp_norm
* Update dpnp.linalg.qr
* Update dpnp.linalg.svd
* Rename and move get_res_type_with_aspect func
* dpnp_inv should return float when got float type
  • Loading branch information
vlad-perevezentsev authored Jul 13, 2023
1 parent cfac723 commit 771653b
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 149 deletions.
32 changes: 28 additions & 4 deletions dpnp/backend/kernels/dpnp_krnl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1152,9 +1152,21 @@ void func_map_init_linalg(func_map_t &fmap)
eft_DBL, (void *)dpnp_eig_default_c<double, double>};

fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_eig_ext_c<int32_t, double>};
get_default_floating_type<>(),
(void *)dpnp_eig_ext_c<
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_eig_ext_c<
int32_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_eig_ext_c<int64_t, double>};
get_default_floating_type<>(),
(void *)dpnp_eig_ext_c<
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_eig_ext_c<
int64_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_eig_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_DBL][eft_DBL] = {
Expand All @@ -1170,9 +1182,21 @@ void func_map_init_linalg(func_map_t &fmap)
eft_DBL, (void *)dpnp_eigvals_default_c<double, double>};

fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_eigvals_ext_c<int32_t, double>};
get_default_floating_type<>(),
(void *)dpnp_eigvals_ext_c<
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_eigvals_ext_c<
int32_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_eigvals_ext_c<int64_t, double>};
get_default_floating_type<>(),
(void *)dpnp_eigvals_ext_c<
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_eigvals_ext_c<
int64_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_eigvals_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_DBL][eft_DBL] = {
Expand Down
60 changes: 52 additions & 8 deletions dpnp/backend/kernels/dpnp_krnl_linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -874,16 +874,28 @@ void func_map_init_linalg_func(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_INV][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_inv_default_c<int64_t, double>};
fmap[DPNPFuncName::DPNP_FN_INV][eft_FLT][eft_FLT] = {
eft_DBL, (void *)dpnp_inv_default_c<float, double>};
eft_DBL, (void *)dpnp_inv_default_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_INV][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_inv_default_c<double, double>};

fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_inv_ext_c<int32_t, double>};
get_default_floating_type<>(),
(void *)dpnp_inv_ext_c<
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_inv_ext_c<
int32_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_inv_ext_c<int64_t, double>};
get_default_floating_type<>(),
(void *)dpnp_inv_ext_c<
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_inv_ext_c<
int64_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_FLT][eft_FLT] = {
eft_DBL, (void *)dpnp_inv_ext_c<float, double>};
eft_FLT, (void *)dpnp_inv_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_inv_ext_c<double, double>};

Expand Down Expand Up @@ -1039,9 +1051,21 @@ void func_map_init_linalg_func(func_map_t &fmap)
// eft_C128, (void*)dpnp_qr_c<std::complex<double>, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_qr_ext_c<int32_t, double>};
get_default_floating_type<>(),
(void *)dpnp_qr_ext_c<
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_qr_ext_c<
int32_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_qr_ext_c<int64_t, double>};
get_default_floating_type<>(),
(void *)dpnp_qr_ext_c<
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)dpnp_qr_ext_c<
int64_t, func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_qr_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_DBL][eft_DBL] = {
Expand All @@ -1062,9 +1086,29 @@ void func_map_init_linalg_func(func_map_t &fmap)
std::complex<double>, double>};

fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_svd_ext_c<int32_t, double, double>};
get_default_floating_type<>(),
(void *)dpnp_svd_ext_c<
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>,
func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)
dpnp_svd_ext_c<int32_t,
func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>,
func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_svd_ext_c<int64_t, double, double>};
get_default_floating_type<>(),
(void *)dpnp_svd_ext_c<
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>,
func_type_map_t::find_type<get_default_floating_type<>()>>,
get_default_floating_type<std::false_type>(),
(void *)
dpnp_svd_ext_c<int64_t,
func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>,
func_type_map_t::find_type<
get_default_floating_type<std::false_type>()>>};
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_svd_ext_c<float, float, float>};
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_DBL][eft_DBL] = {
Expand Down
11 changes: 11 additions & 0 deletions dpnp/backend/src/dpnp_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,17 @@ class dpnp_less_comp
}
};

/**
* A template function that determines the default floating-point type
* based on the value of the template parameter has_fp64.
*/
template <typename has_fp64 = std::true_type>
static constexpr DPNPFuncType get_default_floating_type()
{
return has_fp64::value ? DPNPFuncType::DPNP_FT_DOUBLE
: DPNPFuncType::DPNP_FT_FLOAT;
}

/**
* FPTR interface initialization functions
*/
Expand Down
2 changes: 2 additions & 0 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ cdef extern from "dpnp_iface_fptr.hpp":
struct DPNPFuncData:
DPNPFuncType return_type
void * ptr
DPNPFuncType return_type_no_fp64
void *ptr_no_fp64

DPNPFuncData get_dpnp_function_ptr(DPNPFuncName name, DPNPFuncType first_type, DPNPFuncType second_type) except +

Expand Down
Loading

0 comments on commit 771653b

Please sign in to comment.