Skip to content

Commit

Permalink
rework implementation of diag, diagflat, vander, and ptp (IntelPython…
Browse files Browse the repository at this point in the history
…#1579)

* rework implementation of diag, diagflat, vander, and ptp

* address comments - first round

cherry-pick

* address comments - second round

* add tests for negative use cases to improve covergae

* fixed missing merge conflicts

* fix pre-commit
  • Loading branch information
vtavana authored Nov 20, 2023
1 parent 630dae2 commit 5bcf910
Show file tree
Hide file tree
Showing 16 changed files with 590 additions and 485 deletions.
69 changes: 32 additions & 37 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ enum class DPNPFuncName : size_t
DPNP_FN_DET_EXT, /**< Used in numpy.linalg.det() impl, requires extra
parameters */
DPNP_FN_DIAG, /**< Used in numpy.diag() impl */
DPNP_FN_DIAG_EXT, /**< Used in numpy.diag() impl, requires extra parameters
*/
DPNP_FN_DIAG_INDICES, /**< Used in numpy.diag_indices() impl */
DPNP_FN_DIAG_INDICES, /**< Used in numpy.diag_indices() impl */
DPNP_FN_DIAG_INDICES_EXT, /**< Used in numpy.diag_indices() impl, requires
extra parameters */
DPNP_FN_DIAGONAL, /**< Used in numpy.diagonal() impl */
Expand Down Expand Up @@ -225,25 +223,24 @@ enum class DPNPFuncName : size_t
DPNP_FN_MODF_EXT, /**< Used in numpy.modf() impl, requires extra parameters
*/
DPNP_FN_MULTIPLY, /**< Used in numpy.multiply() impl */
DPNP_FN_MULTIPLY_EXT, /**< Used in numpy.multiply() impl, requires extra
parameters */
DPNP_FN_NANVAR, /**< Used in numpy.nanvar() impl */
DPNP_FN_NANVAR_EXT, /**< Used in numpy.nanvar() impl, requires extra
parameters */
DPNP_FN_NEGATIVE, /**< Used in numpy.negative() impl */
DPNP_FN_NONZERO, /**< Used in numpy.nonzero() impl */
DPNP_FN_ONES, /**< Used in numpy.ones() impl */
DPNP_FN_ONES_LIKE, /**< Used in numpy.ones_like() impl */
DPNP_FN_PARTITION, /**< Used in numpy.partition() impl */
DPNP_FN_PARTITION_EXT, /**< Used in numpy.partition() impl, requires extra
parameters */
DPNP_FN_PLACE, /**< Used in numpy.place() impl */
DPNP_FN_POWER, /**< Used in numpy.power() impl */
DPNP_FN_PROD, /**< Used in numpy.prod() impl */
DPNP_FN_PTP, /**< Used in numpy.ptp() impl */
DPNP_FN_PTP_EXT, /**< Used in numpy.ptp() impl, requires extra parameters */
DPNP_FN_PUT, /**< Used in numpy.put() impl */
DPNP_FN_PUT_ALONG_AXIS, /**< Used in numpy.put_along_axis() impl */
DPNP_FN_MULTIPLY_EXT, /**< Used in numpy.multiply() impl, requires extra
parameters */
DPNP_FN_NANVAR, /**< Used in numpy.nanvar() impl */
DPNP_FN_NANVAR_EXT, /**< Used in numpy.nanvar() impl, requires extra
parameters */
DPNP_FN_NEGATIVE, /**< Used in numpy.negative() impl */
DPNP_FN_NONZERO, /**< Used in numpy.nonzero() impl */
DPNP_FN_ONES, /**< Used in numpy.ones() impl */
DPNP_FN_ONES_LIKE, /**< Used in numpy.ones_like() impl */
DPNP_FN_PARTITION, /**< Used in numpy.partition() impl */
DPNP_FN_PARTITION_EXT, /**< Used in numpy.partition() impl, requires extra
parameters */
DPNP_FN_PLACE, /**< Used in numpy.place() impl */
DPNP_FN_POWER, /**< Used in numpy.power() impl */
DPNP_FN_PROD, /**< Used in numpy.prod() impl */
DPNP_FN_PTP, /**< Used in numpy.ptp() impl */
DPNP_FN_PUT, /**< Used in numpy.put() impl */
DPNP_FN_PUT_ALONG_AXIS, /**< Used in numpy.put_along_axis() impl */
DPNP_FN_PUT_ALONG_AXIS_EXT, /**< Used in numpy.put_along_axis() impl,
requires extra parameters */
DPNP_FN_QR, /**< Used in numpy.linalg.qr() impl */
Expand Down Expand Up @@ -401,21 +398,19 @@ enum class DPNPFuncName : size_t
DPNP_FN_TAKE, /**< Used in numpy.take() impl */
DPNP_FN_TAN, /**< Used in numpy.tan() impl */
DPNP_FN_TANH, /**< Used in numpy.tanh() impl */
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() impl */
DPNP_FN_TRACE, /**< Used in numpy.trace() impl */
DPNP_FN_TRACE_EXT, /**< Used in numpy.trace() impl, requires extra
parameters */
DPNP_FN_TRAPZ, /**< Used in numpy.trapz() impl */
DPNP_FN_TRAPZ_EXT, /**< Used in numpy.trapz() impl, requires extra
parameters */
DPNP_FN_TRI, /**< Used in numpy.tri() impl */
DPNP_FN_TRIL, /**< Used in numpy.tril() impl */
DPNP_FN_TRIU, /**< Used in numpy.triu() impl */
DPNP_FN_TRUNC, /**< Used in numpy.trunc() impl */
DPNP_FN_VANDER, /**< Used in numpy.vander() impl */
DPNP_FN_VANDER_EXT, /**< Used in numpy.vander() impl, requires extra
parameters */
DPNP_FN_VAR, /**< Used in numpy.var() impl */
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() impl */
DPNP_FN_TRACE, /**< Used in numpy.trace() impl */
DPNP_FN_TRACE_EXT, /**< Used in numpy.trace() impl, requires extra
parameters */
DPNP_FN_TRAPZ, /**< Used in numpy.trapz() impl */
DPNP_FN_TRAPZ_EXT, /**< Used in numpy.trapz() impl, requires extra
parameters */
DPNP_FN_TRI, /**< Used in numpy.tri() impl */
DPNP_FN_TRIL, /**< Used in numpy.tril() impl */
DPNP_FN_TRIU, /**< Used in numpy.triu() impl */
DPNP_FN_TRUNC, /**< Used in numpy.trunc() impl */
DPNP_FN_VANDER, /**< Used in numpy.vander() impl */
DPNP_FN_VAR, /**< Used in numpy.var() impl */
DPNP_FN_VAR_EXT, /**< Used in numpy.var() impl, requires extra parameters */
DPNP_FN_ZEROS, /**< Used in numpy.zeros() impl */
DPNP_FN_ZEROS_LIKE, /**< Used in numpy.zeros_like() impl */
Expand Down
74 changes: 0 additions & 74 deletions dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,18 +200,6 @@ void (*dpnp_diag_default_c)(void *,
const size_t,
const size_t) = dpnp_diag_c<_DataType>;

template <typename _DataType>
DPCTLSyclEventRef (*dpnp_diag_ext_c)(DPCTLSyclQueueRef,
void *,
void *,
const int,
shape_elem_type *,
shape_elem_type *,
const size_t,
const size_t,
const DPCTLEventVectorRef) =
dpnp_diag_c<_DataType>;

template <typename _DataType>
DPCTLSyclEventRef dpnp_eye_c(DPCTLSyclQueueRef q_ref,
void *result1,
Expand Down Expand Up @@ -569,23 +557,6 @@ void (*dpnp_ptp_default_c)(void *,
const shape_elem_type *,
const size_t) = dpnp_ptp_c<_DataType>;

template <typename _DataType>
DPCTLSyclEventRef (*dpnp_ptp_ext_c)(DPCTLSyclQueueRef,
void *,
const size_t,
const size_t,
const shape_elem_type *,
const shape_elem_type *,
const void *,
const size_t,
const size_t,
const shape_elem_type *,
const shape_elem_type *,
const shape_elem_type *,
const size_t,
const DPCTLEventVectorRef) =
dpnp_ptp_c<_DataType>;

template <typename _DataType_input, typename _DataType_output>
DPCTLSyclEventRef dpnp_vander_c(DPCTLSyclQueueRef q_ref,
const void *array1_in,
Expand Down Expand Up @@ -673,16 +644,6 @@ void (*dpnp_vander_default_c)(const void *,
const int) =
dpnp_vander_c<_DataType_input, _DataType_output>;

template <typename _DataType_input, typename _DataType_output>
DPCTLSyclEventRef (*dpnp_vander_ext_c)(DPCTLSyclQueueRef,
const void *,
void *,
const size_t,
const size_t,
const int,
const DPCTLEventVectorRef) =
dpnp_vander_c<_DataType_input, _DataType_output>;

template <typename _DataType, typename _ResultType>
class dpnp_trace_c_kernel;

Expand Down Expand Up @@ -1192,15 +1153,6 @@ void func_map_init_arraycreation(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_DIAG][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_diag_default_c<double>};

fmap[DPNPFuncName::DPNP_FN_DIAG_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_diag_ext_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_DIAG_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_diag_ext_c<int64_t>};
fmap[DPNPFuncName::DPNP_FN_DIAG_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_diag_ext_c<float>};
fmap[DPNPFuncName::DPNP_FN_DIAG_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_diag_ext_c<double>};

fmap[DPNPFuncName::DPNP_FN_EYE][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_eye_default_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_EYE][eft_LNG][eft_LNG] = {
Expand Down Expand Up @@ -1284,15 +1236,6 @@ void func_map_init_arraycreation(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_PTP][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_ptp_default_c<double>};

fmap[DPNPFuncName::DPNP_FN_PTP_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_ptp_ext_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_PTP_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_ptp_ext_c<int64_t>};
fmap[DPNPFuncName::DPNP_FN_PTP_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_ptp_ext_c<float>};
fmap[DPNPFuncName::DPNP_FN_PTP_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_ptp_ext_c<double>};

fmap[DPNPFuncName::DPNP_FN_VANDER][eft_INT][eft_INT] = {
eft_LNG, (void *)dpnp_vander_default_c<int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_VANDER][eft_LNG][eft_LNG] = {
Expand All @@ -1308,23 +1251,6 @@ void func_map_init_arraycreation(func_map_t &fmap)
(void *)
dpnp_vander_default_c<std::complex<double>, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_INT][eft_INT] = {
eft_LNG, (void *)dpnp_vander_ext_c<int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_vander_ext_c<int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_vander_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_vander_ext_c<double, double>};
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_BLN][eft_BLN] = {
eft_LNG, (void *)dpnp_vander_ext_c<bool, int64_t>};
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_C64][eft_C64] = {
eft_C64,
(void *)dpnp_vander_ext_c<std::complex<float>, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_C128][eft_C128] = {
eft_C128,
(void *)dpnp_vander_ext_c<std::complex<double>, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_TRACE][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_trace_default_c<int32_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_LNG][eft_INT] = {
Expand Down
6 changes: 0 additions & 6 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_DEGREES_EXT
DPNP_FN_DET
DPNP_FN_DET_EXT
DPNP_FN_DIAG
DPNP_FN_DIAG_EXT
DPNP_FN_DIAG_INDICES
DPNP_FN_DIAG_INDICES_EXT
DPNP_FN_DIAGONAL
Expand Down Expand Up @@ -120,8 +118,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_PARTITION
DPNP_FN_PARTITION_EXT
DPNP_FN_PLACE
DPNP_FN_PTP
DPNP_FN_PTP_EXT
DPNP_FN_QR
DPNP_FN_QR_EXT
DPNP_FN_RADIANS
Expand Down Expand Up @@ -218,8 +214,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_TRIL_EXT
DPNP_FN_TRIU
DPNP_FN_TRIU_EXT
DPNP_FN_VANDER
DPNP_FN_VANDER_EXT
DPNP_FN_VAR
DPNP_FN_VAR_EXT
DPNP_FN_ZEROS
Expand Down
Loading

0 comments on commit 5bcf910

Please sign in to comment.