Skip to content

Commit d9fa2fd

Browse files
[PHI] Support complex for paddle.linalg.lu_unpack , paddle.linalg.lu_solve (#74130)
* support complex for lu_solve * support lu_unpack complex * add test * add docs * fix lu_solve in DCU * fix PADDLE_WITH_HIP * fix lu_solve op test
1 parent 9775db5 commit d9fa2fd

File tree

16 files changed

+343
-23
lines changed

16 files changed

+343
-23
lines changed

paddle/phi/backends/dynload/cusolver.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
9797
__macro(cusolverDnSgetrf); \
9898
__macro(cusolverDnSgetrs); \
9999
__macro(cusolverDnDgetrs); \
100+
__macro(cusolverDnCgetrs); \
101+
__macro(cusolverDnZgetrs); \
100102
__macro(cusolverDnDgetrf); \
101103
__macro(cusolverDnCgetrf); \
102104
__macro(cusolverDnZgetrf); \

paddle/phi/backends/dynload/lapack.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,24 @@ extern "C" void dgetrs_(char *trans,
5252
double *b,
5353
int *ldb,
5454
int *info);
55+
extern "C" void cgetrs_(char *trans,
56+
int *n,
57+
int *nrhs,
58+
std::complex<float> *a,
59+
int *lda,
60+
int *ipiv,
61+
std::complex<float> *b,
62+
int *ldb,
63+
int *info);
64+
extern "C" void zgetrs_(char *trans,
65+
int *n,
66+
int *nrhs,
67+
std::complex<double> *a,
68+
int *lda,
69+
int *ipiv,
70+
std::complex<double> *b,
71+
int *ldb,
72+
int *info);
5573

5674
// evd
5775
extern "C" void zheevd_(char *jobz,
@@ -396,6 +414,8 @@ extern void *lapack_dso_handle;
396414
__macro(zgetrf_); \
397415
__macro(sgetrs_); \
398416
__macro(dgetrs_); \
417+
__macro(cgetrs_); \
418+
__macro(zgetrs_); \
399419
__macro(zheevd_); \
400420
__macro(cheevd_); \
401421
__macro(dsyevd_); \

paddle/phi/backends/dynload/rocsolver.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ extern void *rocsolver_dso_handle;
4747
__macro(rocsolver_zpotrs); \
4848
__macro(rocsolver_sgetrs); \
4949
__macro(rocsolver_dgetrs); \
50+
__macro(rocsolver_cgetrs); \
51+
__macro(rocsolver_zgetrs); \
5052
__macro(rocsolver_sgetrf); \
5153
__macro(rocsolver_dgetrf); \
5254
__macro(rocsolver_cgetrf); \

paddle/phi/kernels/cpu/lu_solve_grad_kernel.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,11 @@
1919
#include "paddle/phi/kernels/lu_solve_grad_kernel.h"
2020

2121
// Register the CPU backward kernel
22-
PD_REGISTER_KERNEL(
23-
lu_solve_grad, CPU, ALL_LAYOUT, phi::LuSolveGradKernel, float, double) {}
22+
PD_REGISTER_KERNEL(lu_solve_grad,
23+
CPU,
24+
ALL_LAYOUT,
25+
phi::LuSolveGradKernel,
26+
float,
27+
double,
28+
phi::dtype::complex<float>,
29+
phi::dtype::complex<double>) {}

paddle/phi/kernels/cpu/lu_solve_kernel.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,11 @@ void LuSolveKernel(const Context& dev_ctx,
7878
}
7979
} // namespace phi
8080

81-
PD_REGISTER_KERNEL(
82-
lu_solve, CPU, ALL_LAYOUT, phi::LuSolveKernel, float, double) {}
81+
PD_REGISTER_KERNEL(lu_solve,
82+
CPU,
83+
ALL_LAYOUT,
84+
phi::LuSolveKernel,
85+
float,
86+
double,
87+
phi::dtype::complex<float>,
88+
phi::dtype::complex<double>) {}

paddle/phi/kernels/cpu/lu_unpack_grad_kernel.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,11 @@
1818
#include "paddle/phi/kernels/impl/lu_unpack_grad_kernel_impl.h"
1919
#include "paddle/phi/kernels/lu_unpack_grad_kernel.h"
2020

21-
PD_REGISTER_KERNEL(
22-
lu_unpack_grad, CPU, ALL_LAYOUT, phi::LUUnpackGradKernel, float, double) {}
21+
PD_REGISTER_KERNEL(lu_unpack_grad,
22+
CPU,
23+
ALL_LAYOUT,
24+
phi::LUUnpackGradKernel,
25+
float,
26+
double,
27+
phi::dtype::complex<float>,
28+
phi::dtype::complex<double>) {}

paddle/phi/kernels/cpu/lu_unpack_kernel.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,11 @@
1818
#include "paddle/phi/kernels/impl/lu_unpack_kernel_impl.h"
1919
#include "paddle/phi/kernels/lu_unpack_kernel.h"
2020

21-
PD_REGISTER_KERNEL(
22-
lu_unpack, CPU, ALL_LAYOUT, phi::LUUnpackKernel, float, double) {}
21+
PD_REGISTER_KERNEL(lu_unpack,
22+
CPU,
23+
ALL_LAYOUT,
24+
phi::LUUnpackKernel,
25+
float,
26+
double,
27+
phi::dtype::complex<float>,
28+
phi::dtype::complex<double>) {}

paddle/phi/kernels/funcs/lapack/lapack_function.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,48 @@ void lapackLuSolve<float>(char trans,
7979
dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
8080
}
8181

82+
template <>
83+
void lapackLuSolve<phi::dtype::complex<float>>(char trans,
84+
int n,
85+
int nrhs,
86+
phi::dtype::complex<float> *a,
87+
int lda,
88+
int *ipiv,
89+
phi::dtype::complex<float> *b,
90+
int ldb,
91+
int *info) {
92+
dynload::cgetrs_(&trans,
93+
&n,
94+
&nrhs,
95+
reinterpret_cast<std::complex<float> *>(a),
96+
&lda,
97+
ipiv,
98+
reinterpret_cast<std::complex<float> *>(b),
99+
&ldb,
100+
info);
101+
}
102+
103+
template <>
104+
void lapackLuSolve<phi::dtype::complex<double>>(char trans,
105+
int n,
106+
int nrhs,
107+
phi::dtype::complex<double> *a,
108+
int lda,
109+
int *ipiv,
110+
phi::dtype::complex<double> *b,
111+
int ldb,
112+
int *info) {
113+
dynload::zgetrs_(&trans,
114+
&n,
115+
&nrhs,
116+
reinterpret_cast<std::complex<double> *>(a),
117+
&lda,
118+
ipiv,
119+
reinterpret_cast<std::complex<double> *>(b),
120+
&ldb,
121+
info);
122+
}
123+
82124
// eigh
83125
template <>
84126
void lapackEigh<float>(char jobz,

paddle/phi/kernels/gpu/lu_grad_kernel.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
// HIP not support cusolver in LUKernel
2323
PD_REGISTER_KERNEL(lu_grad, GPU, ALL_LAYOUT, phi::LUGradKernel, float, double) {
2424
}
25-
#else
25+
#else // PADDLE_WITH_CUDA
2626
PD_REGISTER_KERNEL(lu_grad,
2727
GPU,
2828
ALL_LAYOUT,
@@ -31,4 +31,4 @@ PD_REGISTER_KERNEL(lu_grad,
3131
double,
3232
phi::dtype::complex<float>,
3333
phi::dtype::complex<double>) {}
34-
#endif // PADDLE_WITH_HIP
34+
#endif

paddle/phi/kernels/gpu/lu_solve_grad_kernel.cu

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,18 @@
1818
#include "paddle/phi/kernels/impl/lu_solve_grad_kernel_impl.h"
1919
#include "paddle/phi/kernels/lu_solve_grad_kernel.h"
2020

21+
#ifdef PADDLE_WITH_HIP
22+
// blas_impl.hip.h not support CUBlas<T>::TRSM for complex in
23+
// TriangularSolveKernel
2124
PD_REGISTER_KERNEL(
2225
lu_solve_grad, GPU, ALL_LAYOUT, phi::LuSolveGradKernel, float, double) {}
26+
#else // PADDLE_WITH_CUDA
27+
PD_REGISTER_KERNEL(lu_solve_grad,
28+
GPU,
29+
ALL_LAYOUT,
30+
phi::LuSolveGradKernel,
31+
float,
32+
double,
33+
phi::dtype::complex<float>,
34+
phi::dtype::complex<double>) {}
35+
#endif

0 commit comments

Comments
 (0)