diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index 496b253dff5b3..980b7cb35410b 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -80,8 +80,14 @@ namespace dynload { __macro(cublasSgetriBatched); \ __macro(cublasDgetrfBatched); \ __macro(cublasDgetriBatched); \ + __macro(cublasCgetrfBatched); \ + __macro(cublasCgetriBatched); \ + __macro(cublasZgetrfBatched); \ + __macro(cublasZgetriBatched); \ __macro(cublasSmatinvBatched); \ __macro(cublasDmatinvBatched); \ + __macro(cublasCmatinvBatched); \ + __macro(cublasZmatinvBatched); \ __macro(cublasSgetrsBatched); \ __macro(cublasDgetrsBatched); diff --git a/paddle/phi/backends/dynload/cublas.h b/paddle/phi/backends/dynload/cublas.h index 8053bbb6bd2ce..6da85283d6e71 100644 --- a/paddle/phi/backends/dynload/cublas.h +++ b/paddle/phi/backends/dynload/cublas.h @@ -94,8 +94,14 @@ extern void *cublas_dso_handle; __macro(cublasSgetriBatched); \ __macro(cublasDgetrfBatched); \ __macro(cublasDgetriBatched); \ + __macro(cublasCgetrfBatched); \ + __macro(cublasCgetriBatched); \ + __macro(cublasZgetrfBatched); \ + __macro(cublasZgetriBatched); \ __macro(cublasSmatinvBatched); \ __macro(cublasDmatinvBatched); \ + __macro(cublasCmatinvBatched); \ + __macro(cublasZmatinvBatched); \ __macro(cublasSgetrsBatched); \ __macro(cublasDgetrsBatched); diff --git a/paddle/phi/kernels/cpu/inverse_grad_kernel.cc b/paddle/phi/kernels/cpu/inverse_grad_kernel.cc index 97c10e69c8eab..5014cfd0f95c7 100644 --- a/paddle/phi/kernels/cpu/inverse_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/inverse_grad_kernel.cc @@ -16,5 +16,11 @@ #include "paddle/phi/core/kernel_registry.h" -PD_REGISTER_KERNEL( - inverse_grad, CPU, ALL_LAYOUT, phi::InverseGradKernel, float, double) {} +PD_REGISTER_KERNEL(inverse_grad, + CPU, + ALL_LAYOUT, + phi::InverseGradKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/inverse_kernel.cc b/paddle/phi/kernels/cpu/inverse_kernel.cc index 4b21718eca3f2..6fecef6f888dc 100644 --- a/paddle/phi/kernels/cpu/inverse_kernel.cc +++ b/paddle/phi/kernels/cpu/inverse_kernel.cc @@ -16,5 +16,11 @@ #include "paddle/phi/core/kernel_registry.h" -PD_REGISTER_KERNEL( - inverse, CPU, ALL_LAYOUT, phi::InverseKernel, float, double) {} +PD_REGISTER_KERNEL(inverse, + CPU, + ALL_LAYOUT, + phi::InverseKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index 96b2128eee16c..a58b5998a6703 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -685,6 +685,63 @@ struct CUBlas> { ldb, batch_size)); } + + static void GETRF_BATCH(cublasHandle_t handle, + int n, + phi::dtype::complex **A, + int lda, + int *ipiv, + int *info, + int batch_size) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetrfBatched( + handle, + n, + reinterpret_cast(A), + lda, + ipiv, + info, + batch_size)); + } + + static void GETRI_BATCH(cublasHandle_t handle, + int n, + const phi::dtype::complex **A, + int lda, + const int *ipiv, + phi::dtype::complex **Ainv, + int ldc, + int *info, + int batch_size) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetriBatched( + handle, + n, + reinterpret_cast(A), + lda, + ipiv, + reinterpret_cast(Ainv), + ldc, + info, + batch_size)); + } + + static void MATINV_BATCH(cublasHandle_t handle, + int n, + const phi::dtype::complex **A, + int lda, + phi::dtype::complex **Ainv, + int lda_inv, + int *info, + int batch_size) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCmatinvBatched( + handle, + n, + reinterpret_cast(A), + lda, + reinterpret_cast(Ainv), + lda_inv, + info, + batch_size)); + } }; template <> @@ -923,6 +980,63 @@ struct CUBlas> { "cublasGemmEx is not supported on cuda <= 7.5")); #endif } + + static void GETRF_BATCH(cublasHandle_t handle, + int n, + phi::dtype::complex **A, + int lda, + int *ipiv, + int *info, + int batch_size) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetrfBatched( + handle, + n, + reinterpret_cast(A), + lda, + ipiv, + info, + batch_size)); + } + + static void GETRI_BATCH(cublasHandle_t handle, + int n, + const phi::dtype::complex **A, + int lda, + const int *ipiv, + phi::dtype::complex **Ainv, + int ldc, + int *info, + int batch_size) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetriBatched( + handle, + n, + reinterpret_cast(A), + lda, + ipiv, + reinterpret_cast(Ainv), + ldc, + info, + batch_size)); + } + + static void MATINV_BATCH(cublasHandle_t handle, + int n, + const phi::dtype::complex **A, + int lda, + phi::dtype::complex **Ainv, + int lda_inv, + int *info, + int batch_size) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZmatinvBatched( + handle, + n, + reinterpret_cast(A), + lda, + reinterpret_cast(Ainv), + lda_inv, + info, + batch_size)); + } }; template <> diff --git a/paddle/phi/kernels/funcs/matrix_inverse.cc b/paddle/phi/kernels/funcs/matrix_inverse.cc index c316970e6a560..5511507e7accc 100644 --- a/paddle/phi/kernels/funcs/matrix_inverse.cc +++ b/paddle/phi/kernels/funcs/matrix_inverse.cc @@ -28,6 +28,8 @@ void MatrixInverseFunctor::operator()(const Context& dev_ctx, template class MatrixInverseFunctor; template class MatrixInverseFunctor; +template class MatrixInverseFunctor>; +template class MatrixInverseFunctor>; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/matrix_inverse.cu b/paddle/phi/kernels/funcs/matrix_inverse.cu index c0ea7ad84c41b..f46dd714c9f55 100644 --- a/paddle/phi/kernels/funcs/matrix_inverse.cu +++ b/paddle/phi/kernels/funcs/matrix_inverse.cu @@ -131,6 +131,8 @@ void MatrixInverseFunctor::operator()(const Context& dev_ctx, template class MatrixInverseFunctor; template class MatrixInverseFunctor; +template class MatrixInverseFunctor>; +template class MatrixInverseFunctor>; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/matrix_inverse.h b/paddle/phi/kernels/funcs/matrix_inverse.h index f0cd265a54648..d45f7d8863a63 100644 --- a/paddle/phi/kernels/funcs/matrix_inverse.h +++ b/paddle/phi/kernels/funcs/matrix_inverse.h @@ -25,14 +25,69 @@ limitations under the License. */ namespace phi { namespace funcs { +template +struct MapMatrixInverseFunctor { + void operator()( + const Context& dev_ctx, const T* a_ptr, T* a_inv_ptr, int offset, int n) { + using Matrix = + Eigen::Matrix; + using EigenMatrixMap = Eigen::Map; + using ConstEigenMatrixMap = Eigen::Map; + + ConstEigenMatrixMap mat(a_ptr + offset, n, n); + EigenMatrixMap mat_inv(a_inv_ptr + offset, n, n); + Eigen::PartialPivLU lu; + lu.compute(mat); + + const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff(); + PADDLE_ENFORCE_GT(min_abs_pivot, + static_cast(0), + errors::InvalidArgument("Input is not invertible.")); + mat_inv.noalias() = lu.inverse(); + } +}; + +template +struct MapMatrixInverseFunctor> { + void operator()(const Context& dev_ctx, + const phi::dtype::complex* a_ptr, + phi::dtype::complex* a_inv_ptr, + int offset, + int n) { + using Matrix = Eigen::Matrix, + Eigen::Dynamic, + Eigen::Dynamic, + Eigen::RowMajor>; + using EigenMatrixMap = Eigen::Map; + using ConstEigenMatrixMap = Eigen::Map; + std::complex* std_ptr = new std::complex[n * n]; + std::complex* std_inv_ptr = new std::complex[n * n]; + for (int i = 0; i < n * n; i++) { + *(std_ptr + i) = static_cast>(*(a_ptr + offset + i)); + } + ConstEigenMatrixMap mat(std_ptr, n, n); + EigenMatrixMap mat_inv(std_inv_ptr, n, n); + Eigen::PartialPivLU lu; + lu.compute(mat); + + const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff(); + PADDLE_ENFORCE_NE(min_abs_pivot, + static_cast>(0), + errors::InvalidArgument("Input is not invertible.")); + mat_inv.noalias() = lu.inverse(); + for (int i = 0; i < n * n; i++) { + *(a_inv_ptr + offset + i) = + static_cast>(*(std_inv_ptr + i)); + } + delete[] std_ptr; + delete[] std_inv_ptr; + } +}; + template void ComputeInverseEigen(const Context& dev_ctx, const DenseTensor& a, DenseTensor* a_inv) { - using Matrix = - Eigen::Matrix; - using EigenMatrixMap = Eigen::Map; - using ConstEigenMatrixMap = Eigen::Map; const auto& mat_dims = a.dims(); const int rank = mat_dims.size(); int n = mat_dims[rank - 1]; @@ -41,17 +96,13 @@ void ComputeInverseEigen(const Context& dev_ctx, const T* a_ptr = a.data(); T* a_inv_ptr = dev_ctx.template Alloc(a_inv); + // Putting phi::dtype::complex into eigen::matrix has a problem, + // it's not going to get the right result, + // so we're going to convert it to std::complex and + // then we're going to put it into eigen::matrix. for (int i = 0; i < batch_size; ++i) { - ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n); - EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n); - Eigen::PartialPivLU lu; - lu.compute(mat); - - const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff(); - PADDLE_ENFORCE_GT(min_abs_pivot, - static_cast(0), - errors::InvalidArgument("Input is not invertible.")); - mat_inv.noalias() = lu.inverse(); + MapMatrixInverseFunctor functor; + functor(dev_ctx, a_ptr, a_inv_ptr, i * n * n, n); } } diff --git a/paddle/phi/kernels/gpu/inverse_grad_kernel.cu b/paddle/phi/kernels/gpu/inverse_grad_kernel.cu index 2fdc02934fedc..15c24719adfc3 100644 --- a/paddle/phi/kernels/gpu/inverse_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/inverse_grad_kernel.cu @@ -18,5 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/inverse_grad_kernel_impl.h" -PD_REGISTER_KERNEL( - inverse_grad, GPU, ALL_LAYOUT, phi::InverseGradKernel, float, double) {} +PD_REGISTER_KERNEL(inverse_grad, + GPU, + ALL_LAYOUT, + phi::InverseGradKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/inverse_kernel.cu b/paddle/phi/kernels/gpu/inverse_kernel.cu index 4c011337c6f8f..a9b4fcc763b0b 100644 --- a/paddle/phi/kernels/gpu/inverse_kernel.cu +++ b/paddle/phi/kernels/gpu/inverse_kernel.cu @@ -18,5 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/inverse_kernel_impl.h" -PD_REGISTER_KERNEL( - inverse, GPU, ALL_LAYOUT, phi::InverseKernel, float, double) {} +PD_REGISTER_KERNEL(inverse, + GPU, + ALL_LAYOUT, + phi::InverseKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/inverse_grad_kernel_impl.h b/paddle/phi/kernels/impl/inverse_grad_kernel_impl.h index 26e2898bf73ff..aa23bddb5b979 100644 --- a/paddle/phi/kernels/impl/inverse_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/inverse_grad_kernel_impl.h @@ -18,6 +18,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/kernels/complex_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/matrix_inverse.h" @@ -37,15 +38,35 @@ void InverseGradKernel(const Context& dev_ctx, tmp_out.Resize(out.dims()); dev_ctx.template Alloc(&tmp_out); - auto mat_dim_a0 = - phi::funcs::CreateMatrixDescriptor(out_grad.dims(), 0, false); - auto mat_dim_b0 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true); - blas.MatMul(out_grad, mat_dim_a0, out, mat_dim_b0, T(1), &tmp_out, T(0)); + if (IsComplexType(out.dtype())) { + DenseTensor out_conj; + out_conj.Resize(out.dims()); + dev_ctx.template Alloc(&out_conj); - auto mat_dim_a1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true); - auto mat_dim_b1 = - phi::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false); - blas.MatMul(out, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), in_grad, T(0)); + phi::ConjKernel(dev_ctx, out, &out_conj); + + auto mat_dim_a0 = + phi::funcs::CreateMatrixDescriptor(out_grad.dims(), 0, false); + auto mat_dim_b0 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true); + blas.MatMul( + out_grad, mat_dim_a0, out_conj, mat_dim_b0, T(1), &tmp_out, T(0)); + + auto mat_dim_a1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true); + auto mat_dim_b1 = + phi::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false); + blas.MatMul( + out_conj, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), in_grad, T(0)); + } else { + auto mat_dim_a0 = + phi::funcs::CreateMatrixDescriptor(out_grad.dims(), 0, false); + auto mat_dim_b0 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true); + blas.MatMul(out_grad, mat_dim_a0, out, mat_dim_b0, T(1), &tmp_out, T(0)); + + auto mat_dim_a1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true); + auto mat_dim_b1 = + phi::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false); + blas.MatMul(out, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), in_grad, T(0)); + } } } diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index d7d8669ff0c3b..88acb69e41d6f 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2726,7 +2726,7 @@ def inverse(x, name=None): x (Tensor): The input tensor. The last two dimensions should be equal. When the number of dimensions is greater than 2, it is treated as batches of square matrix. The data - type can be float32 and float64. + type can be float32, float64, complex64, complex128. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -2751,7 +2751,12 @@ def inverse(x, name=None): else: def _check_input(x): - check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'inverse') + check_variable_and_dtype( + x, + 'x', + ['float32', 'float64', 'complex64', 'complex128'], + 'inverse', + ) if len(x.shape) < 2: raise ValueError( "The input of inverse is expected to be a Tensor whose number " diff --git a/test/deprecated/legacy_test/test_inverse_op.py b/test/deprecated/legacy_test/test_inverse_op.py index 22810eecee07d..54f8466bd4d02 100644 --- a/test/deprecated/legacy_test/test_inverse_op.py +++ b/test/deprecated/legacy_test/test_inverse_op.py @@ -35,6 +35,12 @@ def setUp(self): np.random.seed(123) mat = np.random.random(self.matrix_shape).astype(self.dtype) + if self.dtype == 'complex64' or self.dtype == 'complex128': + mat = ( + np.random.random(self.matrix_shape) + + 1j * np.random.random(self.matrix_shape) + ).astype(self.dtype) + inverse = np.linalg.inv(mat) self.inputs = {'Input': mat} @@ -92,6 +98,26 @@ def config(self): self.python_api = paddle.tensor.math.inverse +class TestInverseOpComplex64(TestInverseOp): + def config(self): + self.matrix_shape = [10, 10] + self.dtype = "complex64" + self.python_api = paddle.tensor.math.inverse + + def test_grad(self): + self.check_grad(['Input'], 'Output', check_pir=True) + + +class TestInverseOpComplex128(TestInverseOp): + def config(self): + self.matrix_shape = [10, 10] + self.dtype = "complex128" + self.python_api = paddle.tensor.math.inverse + + def test_grad(self): + self.check_grad(['Input'], 'Output', check_pir=True) + + class TestInverseAPI(unittest.TestCase): def setUp(self): np.random.seed(123)